diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
index 78ae512eb54202..e57c2b30808d82 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
@@ -10,6 +10,10 @@ package(
     licenses = ["notice"],
 )
 
+exports_files(glob([
+    "testdata/*.bin",
+]))
+
 package_group(
     name = "friends",
     packages = [
@@ -123,39 +127,39 @@ tf_cc_test(
     name = "quantize_model_test",
     srcs = ["quantize_model_test.cc"],
     args = [
-        "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
+        "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)",
     ],
     data = [
-        "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin",
-        "//tensorflow/lite/tools/optimize:testdata/argmax.bin",
-        "//tensorflow/lite/tools/optimize:testdata/broadcast_to.bin",
-        "//tensorflow/lite/tools/optimize:testdata/concat.bin",
-        "//tensorflow/lite/tools/optimize:testdata/fc.bin",
-        "//tensorflow/lite/tools/optimize:testdata/fc_qat.bin",
-        "//tensorflow/lite/tools/optimize:testdata/gather_nd.bin",
-        "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin",
-        "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin",
-        "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin",
-        "//tensorflow/lite/tools/optimize:testdata/lstm_quantized2.bin",
-        "//tensorflow/lite/tools/optimize:testdata/maximum.bin",
-        "//tensorflow/lite/tools/optimize:testdata/minimum.bin",
-        "//tensorflow/lite/tools/optimize:testdata/mixed.bin",
-        "//tensorflow/lite/tools/optimize:testdata/mixed16x8.bin",
-        "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin",
-        "//tensorflow/lite/tools/optimize:testdata/pack.bin",
-        "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin",
-        "//tensorflow/lite/tools/optimize:testdata/single_conv_no_bias.bin",
-        "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
-        "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin",
-        "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin",
-        "//tensorflow/lite/tools/optimize:testdata/split.bin",
-        "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin",
-        "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin",
-        "//tensorflow/lite/tools/optimize:testdata/transpose.bin",
-        "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_calibrated.bin",
-        "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_quantized.bin",
-        "//tensorflow/lite/tools/optimize:testdata/unpack.bin",
-        "//tensorflow/lite/tools/optimize:testdata/where.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/add_with_const_input.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/argmax.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/broadcast_to.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/concat.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc_qat.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/gather_nd.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated2.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized2.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/maximum.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/minimum.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed16x8.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/multi_input_add_reshape.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/pack.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_avg_pool_min_minus_5_max_plus_5.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_no_bias.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_minus_127_max_plus_127.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_softmax_min_minus_5_max_plus_5.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/split.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_calibrated.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_quantized.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/transpose.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_calibrated.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_quantized.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unpack.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/where.bin",
     ],
     tags = [
         "tflite_not_portable_android",
@@ -163,12 +167,12 @@ tf_cc_test(
     ],
     deps = [
         ":quantize_model",
+        ":test_util",
         "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
         "//tensorflow/compiler/mlir/lite/schema:schema_utils",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/lite:framework",
-        "//tensorflow/lite/tools/optimize:test_util",
         "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest",
@@ -181,13 +185,13 @@ tf_cc_test(
     name = "quantize_weights_test",
     srcs = ["quantize_weights_test.cc"],
     args = [
-        "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
+        "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)",
     ],
     data = [
-        "//tensorflow/lite/tools/optimize:testdata/custom_op.bin",
-        "//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin",
-        "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
-        "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/custom_op.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/quantized_with_gather.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/weight_shared_between_convs.bin",
     ],
     tags = [
         # TODO(b/327796566): re-enable after the bug is fixed
@@ -200,15 +204,28 @@ tf_cc_test(
     ],
     deps = [
         ":quantize_weights",
+        ":test_util",
         "//tensorflow/compiler/mlir/lite/schema:schema_fbs",
         "//tensorflow/compiler/mlir/lite/schema:schema_utils",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/lite:framework",
-        "//tensorflow/lite/tools/optimize:test_util",
        "@com_google_absl//absl/status",
         "@com_google_googletest//:gtest",
         "@flatbuffers",
         "@local_tsl//tsl/platform:logging",
     ],
 )
+
+cc_library(
+    name = "test_util",
+    testonly = 1,
+    srcs = ["test_util.cc"],
+    hdrs = ["test_util.h"],
+    deps = [
+        "//tensorflow/lite:framework",
+        "//tensorflow/lite/core/api",
+        "@com_google_googletest//:gtest",
+        "@flatbuffers",
+    ],
+)
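For readers tracing the `args`/`data` plumbing above: Bazel expands each `$(location //...:testdata/foo.bin)` to a runfiles-relative path, and the test binary recovers it from the `--test_model_file` flag in its `main()`. Below is a minimal sketch of that flag handling, assuming the standard `tensorflow/core/util/command_line_flags.h` API; the variable names and the surrounding `main()` are illustrative, not copied from the tests.

```cc
#include <iostream>
#include <string>
#include <vector>

#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/util/command_line_flags.h"

int main(int argc, char** argv) {
  std::string model_file;  // Receives the $(location ...) expansion.
  const std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("test_model_file", &model_file,
                       "Path to test tflite model file."),
  };
  // Parse strips --test_model_file from argv; InitMain/gtest get the rest.
  if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
    std::cerr << tensorflow::Flags::Usage(argv[0], flag_list);
    return 1;
  }
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  // The real tests derive the testdata directory from model_file and load
  // the other .bin models listed in `data` relative to it.
  return 0;
}
```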
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc
index e7d5e00b703392..1e7cdcdea07d33 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include "absl/status/status.h"
 #include "flatbuffers/flatbuffer_builder.h"  // from @flatbuffers
 #include "flatbuffers/vector.h"  // from @flatbuffers
+#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h"
 #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
 #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h"
 #include "tensorflow/core/platform/init_main.h"
@@ -37,7 +38,6 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/command_line_flags.h"
 #include "tensorflow/lite/model_builder.h"
-#include "tensorflow/lite/tools/optimize/test_util.h"
 #include "tsl/lib/core/status_test_util.h"
 
 // Note: branched from tensorflow/lite/tools/optimize/quantize_model_test.cc
@@ -192,7 +192,8 @@ void VerifyQuantizationScale(
 class QuantizeModelTest : public testing::Test {
  protected:
   QuantizeModelTest() {
-    input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
+    input_model_ =
+        ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -277,7 +278,8 @@ class QuantizeConvModelTest : public QuantizeModelTest,
  protected:
   QuantizeConvModelTest() {
     tensor_type_ = GetParam();
-    input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
+    input_model_ =
+        ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
     // Flatbuffer is missing calibration data -- add dummy params.
@@ -347,7 +349,7 @@ TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) {
 class QuantizeConvNoBiasModelTest : public QuantizeModelTest {
  protected:
   QuantizeConvNoBiasModelTest() {
-    input_model_ = ReadModel(internal::kConvModelWithNoBias);
+    input_model_ = ReadModel(::mlir::lite::internal::kConvModelWithNoBias);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -367,7 +369,7 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest {
 class QuantizeSplitModelTest : public QuantizeModelTest {
  protected:
   QuantizeSplitModelTest() {
-    input_model_ = ReadModel(internal::kModelSplit);
+    input_model_ = ReadModel(::mlir::lite::internal::kModelSplit);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -452,7 +454,8 @@ class QuantizeConvModel2Test : public QuantizeModelTest,
  protected:
   QuantizeConvModel2Test() {
     tensor_type_ = GetParam();
-    input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights);
+    input_model_ =
+        ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
     auto& subgraph = model_.subgraphs[0];
@@ -690,7 +693,8 @@ TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) {
 class QuantizeSoftmaxTest : public QuantizeModelTest {
  protected:
   QuantizeSoftmaxTest() {
-    input_model_ = ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5);
+    input_model_ =
+        ReadModel(::mlir::lite::internal::kSingleSoftmaxModelMinMinus5MaxPlus5);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -753,7 +757,8 @@ TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) {
 class QuantizeAvgPoolTest : public QuantizeModelTest {
  protected:
   QuantizeAvgPoolTest() {
-    input_model_ = ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5);
+    input_model_ =
+        ReadModel(::mlir::lite::internal::kSingleAvgPoolModelMinMinus5MaxPlus5);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -813,7 +818,7 @@ TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) {
 class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest {
  protected:
   QuantizeMultiInputAddWithReshapeTest() {
-    input_model_ = ReadModel(internal::kMultiInputAddWithReshape);
+    input_model_ = ReadModel(::mlir::lite::internal::kMultiInputAddWithReshape);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -933,7 +938,7 @@ class QuantizeConstInputTest : public QuantizeModelTest,
  protected:
   QuantizeConstInputTest() {
     tensor_type_ = GetParam();
-    input_model_ = ReadModel(internal::kConstInputAddModel);
+    input_model_ = ReadModel(::mlir::lite::internal::kConstInputAddModel);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -980,7 +985,7 @@ TEST_P(QuantizeConstInputTest, VerifyConstOpInput) {
 class QuantizeArgMaxTest : public QuantizeModelTest {
  protected:
   QuantizeArgMaxTest() {
-    input_model_ = ReadModel(internal::kModelWithArgMaxOp);
+    input_model_ = ReadModel(::mlir::lite::internal::kModelWithArgMaxOp);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -1025,7 +1030,7 @@ TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
 class QuantizeLSTMTest : public QuantizeModelTest {
  protected:
   QuantizeLSTMTest() {
-    input_model_ = ReadModel(internal::kLstmCalibrated);
+    input_model_ = ReadModel(::mlir::lite::internal::kLstmCalibrated);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -1037,7 +1042,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) {
       /*allow_float=*/true, TensorType_INT8, output_buffer_));
 
   // Read expected model.
-  auto expected_fb_model = ReadModel(internal::kLstmQuantized);
+  auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized);
   auto expected_read_only_model = expected_fb_model->GetModel();
   ModelT expected_model;
   expected_read_only_model->UnPackTo(&expected_model);
@@ -1048,7 +1053,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) {
 class QuantizeLSTM2Test : public QuantizeModelTest {
  protected:
   QuantizeLSTM2Test() {
-    input_model_ = ReadModel(internal::kLstmCalibrated2);
+    input_model_ = ReadModel(::mlir::lite::internal::kLstmCalibrated2);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -1061,7 +1066,7 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
       /*allow_float=*/false, TensorType_INT8, output_buffer_));
 
   // Read expected model.
-  auto expected_fb_model = ReadModel(internal::kLstmQuantized2);
+  auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized2);
   auto expected_read_only_model = expected_fb_model->GetModel();
   ModelT expected_model;
   expected_read_only_model->UnPackTo(&expected_model);
@@ -1072,7 +1077,8 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
 class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest {
  protected:
   QuantizeUnidirectionalSequenceLSTMTest() {
-    input_model_ = ReadModel(internal::kUnidirectionalSequenceLstmCalibrated);
+    input_model_ = ReadModel(
+        ::mlir::lite::internal::kUnidirectionalSequenceLstmCalibrated);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -1086,7 +1092,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest,
 
   // Read expected model.
   auto expected_fb_model =
-      ReadModel(internal::kUnidirectionalSequenceLstmQuantized);
+      ReadModel(::mlir::lite::internal::kUnidirectionalSequenceLstmQuantized);
   auto expected_read_only_model = expected_fb_model->GetModel();
   ModelT expected_model;
   expected_read_only_model->UnPackTo(&expected_model);
@@ -1097,7 +1103,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest,
 class QuantizeSVDFTest : public QuantizeModelTest {
  protected:
   QuantizeSVDFTest() {
-    input_model_ = ReadModel(internal::kSvdfCalibrated);
+    input_model_ = ReadModel(::mlir::lite::internal::kSvdfCalibrated);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -1110,7 +1116,7 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) {
       /*allow_float=*/false, TensorType_INT8, output_buffer_));
 
   // Read expected model.
-  auto expected_fb_model = ReadModel(internal::kSvdfQuantized);
+  auto expected_fb_model = ReadModel(::mlir::lite::internal::kSvdfQuantized);
   auto expected_read_only_model = expected_fb_model->GetModel();
   ModelT expected_model;
   expected_read_only_model->UnPackTo(&expected_model);
@@ -1123,7 +1129,7 @@ class QuantizeFCTest : public QuantizeModelTest,
  protected:
   QuantizeFCTest() {
     disable_per_channel_quantization_for_dense_ = GetParam();
-    input_model_ = ReadModel(internal::kModelWithFCOp);
+    input_model_ = ReadModel(::mlir::lite::internal::kModelWithFCOp);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -1371,7 +1377,7 @@ class QuantizeCustomOpTest
     public ::testing::WithParamInterface<TensorType> {
  protected:
   QuantizeCustomOpTest() {
-    input_model_ = ReadModel(internal::kModelMixed);
+    input_model_ = ReadModel(::mlir::lite::internal::kModelMixed);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -1409,7 +1415,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest,
 class QuantizePackTest : public QuantizeModelTest {
  protected:
   QuantizePackTest() {
-    input_model_ = ReadModel(internal::kModelPack);
+    input_model_ = ReadModel(::mlir::lite::internal::kModelPack);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -1526,14 +1532,15 @@ TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) {
               Eq(input2->quantization->zero_point));
 }
 
-INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest,
-                         testing::ValuesIn({internal::kModelWithMinimumOp,
-                                            internal::kModelWithMaximumOp}));
+INSTANTIATE_TEST_SUITE_P(
+    MinimumMaximumTestInst, QuantizeMinimumMaximumTest,
+    testing::ValuesIn({::mlir::lite::internal::kModelWithMinimumOp,
+                       ::mlir::lite::internal::kModelWithMaximumOp}));
 
 class QuantizeUnpackTest : public QuantizeModelTest {
  protected:
   QuantizeUnpackTest() {
-    input_model_ = ReadModel(internal::kModelWithUnpack);
+    input_model_ = ReadModel(::mlir::lite::internal::kModelWithUnpack);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -1583,7 +1590,7 @@ class QuantizeBroadcastToModelTest
  protected:
   QuantizeBroadcastToModelTest() {
     tensor_type_ = GetParam();
-    input_model_ = ReadModel(internal::kModelWithBroadcastToOp);
+    input_model_ = ReadModel(::mlir::lite::internal::kModelWithBroadcastToOp);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -1646,7 +1653,7 @@ class QuantizeGatherNDModelTest
  protected:
   QuantizeGatherNDModelTest() {
     tensor_type_ = GetParam();
-    input_model_ = ReadModel(internal::kModelWithGatherNDOp);
+    input_model_ = ReadModel(::mlir::lite::internal::kModelWithGatherNDOp);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
@@ -1706,7 +1713,7 @@ TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) {
 class QuantizeWhereModelTest : public QuantizeModelTest {
  protected:
   QuantizeWhereModelTest() {
-    input_model_ = ReadModel(internal::kModelWithWhereOp);
+    input_model_ = ReadModel(::mlir::lite::internal::kModelWithWhereOp);
     readonly_model_ = input_model_->GetModel();
     model_ = UnPackFlatBufferModel(*readonly_model_);
   }
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc
index 2e80bcae7486b4..7a42e74c2619af 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h"
 
 #include
+#include
 #include
 #include
 #include
@@ -26,6 +27,7 @@ limitations under the License.
 #include "flatbuffers/buffer.h"  // from @flatbuffers
 #include "flatbuffers/flatbuffer_builder.h"  // from @flatbuffers
 #include "flatbuffers/vector.h"  // from @flatbuffers
+#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h"
 #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
 #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h"
 #include "tensorflow/core/platform/init_main.h"
@@ -33,7 +35,6 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/command_line_flags.h"
 #include "tensorflow/lite/model_builder.h"
-#include "tensorflow/lite/tools/optimize/test_util.h"
 #include "tsl/platform/logging.h"
 
 // Note: branched from tensorflow/lite/tools/optimize/quantize_weights_test.cc
@@ -59,25 +60,25 @@ std::unique_ptr<ModelT> CreateMutableModelFromFile(const Model* input_model) {
 
 std::unique_ptr<FlatBufferModel> ReadTestModel() {
   auto model_path = tensorflow::io::JoinPath(
-      *g_test_model_dir, internal::kConvModelWith0Plus10Weights);
+      *g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights);
   return FlatBufferModel::BuildFromFile(model_path.c_str());
 }
 
 std::unique_ptr<FlatBufferModel> ReadSharedWeightsTestModel() {
-  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir,
-                                             internal::kModelWithSharedWeights);
+  auto model_path = tensorflow::io::JoinPath(
+      *g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights);
   return FlatBufferModel::BuildFromFile(model_path.c_str());
 }
 
 std::unique_ptr<FlatBufferModel> ReadGatherTestModel() {
-  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir,
-                                             internal::kQuantizedWithGather);
+  auto model_path = tensorflow::io::JoinPath(
+      *g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather);
   return FlatBufferModel::BuildFromFile(model_path.c_str());
 }
 
 std::unique_ptr<FlatBufferModel> ReadCustomOpTestModel() {
-  auto model_path =
-      tensorflow::io::JoinPath(*g_test_model_dir, internal::kModelWithCustomOp);
+  auto model_path = tensorflow::io::JoinPath(
+      *g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp);
   return FlatBufferModel::BuildFromFile(model_path.c_str());
 }
diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc
similarity index 95%
rename from tensorflow/lite/tools/optimize/test_util.cc
rename to tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc
index 5ca45326d1dcad..e096868eec8807 100644
--- a/tensorflow/lite/tools/optimize/test_util.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc
@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/lite/tools/optimize/test_util.h"
+#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h"
 
 #include
 
-namespace tflite {
-namespace optimize {
+namespace mlir {
+namespace lite {
 namespace internal {
 const char* kConvModelWithMinus128Plus127Weights =
     "single_conv_weights_min_minus_127_max_plus_127.bin";
@@ -89,5 +89,5 @@ int FailOnErrorReporter::Report(const char* format, va_list args) {
   return 0;
 }
 }  // namespace internal
-}  // namespace optimize
-}  // namespace tflite
+}  // namespace lite
+}  // namespace mlir
diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h
similarity index 93%
rename from tensorflow/lite/tools/optimize/test_util.h
rename to tensorflow/compiler/mlir/lite/quantization/lite/test_util.h
index 11e7ef230910f2..b4e317c131888e 100644
--- a/tensorflow/lite/tools/optimize/test_util.h
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h
@@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_
-#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_
+#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_
+#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_
 
 #include "tensorflow/lite/core/api/error_reporter.h"
 
-namespace tflite {
-namespace optimize {
+namespace mlir {
+namespace lite {
 namespace internal {
 // Test model with a single convolution.
 // Floating point weights of the model are all integers and lie in
@@ -132,12 +132,12 @@ extern const char* kQatModelWithFc;
 extern const char* kModelWithResourceVarsCalibrated;
 
 // An error reporter that fails on testing.
-class FailOnErrorReporter : public ErrorReporter {
+class FailOnErrorReporter : public tflite::ErrorReporter {
  public:
   int Report(const char* format, va_list args) override;
 };
 }  // namespace internal
-}  // namespace optimize
-}  // namespace tflite
+}  // namespace lite
+}  // namespace mlir
 
-#endif  // TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_
+#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_
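With the header relocated and the namespace switched from `tflite::optimize` to `mlir::lite`, call sites change as the hunks above show. A small usage sketch follows, assuming only what the moved header declares; the helper function and its parameters are illustrative, not part of the patch.

```cc
#include <memory>
#include <string>

#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h"
#include "tensorflow/lite/model_builder.h"

// Builds the single-conv test model from a testdata directory. The model
// constants now live in mlir::lite::internal; FailOnErrorReporter still
// derives from tflite::ErrorReporter, so TFLite APIs accept it unchanged.
std::unique_ptr<tflite::FlatBufferModel> LoadConvTestModel(
    const std::string& testdata_dir, tflite::ErrorReporter* reporter) {
  const std::string path =
      testdata_dir + "/" +
      ::mlir::lite::internal::kConvModelWith0Plus10Weights;
  return tflite::FlatBufferModel::BuildFromFile(path.c_str(), reporter);
}
```

Typical call site: `::mlir::lite::internal::FailOnErrorReporter reporter; auto model = LoadConvTestModel(dir, &reporter);`.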
diff --git a/tensorflow/lite/tools/optimize/testdata/README.md b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/README.md
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/README.md
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/README.md
diff --git a/tensorflow/lite/tools/optimize/testdata/add_with_const_input.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/add_with_const_input.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/add_with_const_input.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/add_with_const_input.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/argmax.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/argmax.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/argmax.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/argmax.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/broadcast_to.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/broadcast_to.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/broadcast_to.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/broadcast_to.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/concat.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/concat.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/concat.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/concat.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/custom_op.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/custom_op.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/custom_op.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/custom_op.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/fc.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/fc.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/fc_qat.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc_qat.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/fc_qat.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc_qat.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/gather_nd.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/gather_nd.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/gather_nd.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/gather_nd.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/lstm_calibrated.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_calibrated2.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated2.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/lstm_calibrated2.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated2.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/lstm_quantized.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_quantized2.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized2.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/lstm_quantized2.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized2.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/maximum.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/maximum.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/maximum.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/maximum.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/minimum.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/minimum.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/minimum.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/minimum.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/mixed.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/mixed.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/mixed16x8.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed16x8.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/mixed16x8.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed16x8.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/multi_input_add_reshape.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/multi_input_add_reshape.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/multi_input_add_reshape.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/multi_input_add_reshape.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/pack.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/pack.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/pack.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/pack.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/quantized_with_gather.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/quantized_with_gather.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/resource_vars_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/resource_vars_calibrated.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/resource_vars_calibrated.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/resource_vars_calibrated.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/single_avg_pool_min_minus_5_max_plus_5.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_avg_pool_min_minus_5_max_plus_5.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/single_avg_pool_min_minus_5_max_plus_5.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_avg_pool_min_minus_5_max_plus_5.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_no_bias.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_no_bias.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/single_conv_no_bias.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_no_bias.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_0_max_plus_10.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_0_max_plus_10.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_0_max_plus_10.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_0_max_plus_10.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_minus_127_max_plus_127.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_minus_127_max_plus_127.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_minus_127_max_plus_127.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_minus_127_max_plus_127.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/single_softmax_min_minus_5_max_plus_5.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_softmax_min_minus_5_max_plus_5.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/single_softmax_min_minus_5_max_plus_5.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_softmax_min_minus_5_max_plus_5.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/split.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/split.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/split.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/split.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/svdf_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_calibrated.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/svdf_calibrated.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_calibrated.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/svdf_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_quantized.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/svdf_quantized.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_quantized.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/transpose.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/transpose.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/transpose.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/transpose.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_calibrated.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_calibrated.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_calibrated.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_quantized.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_quantized.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_quantized.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/unpack.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unpack.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/unpack.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unpack.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/weight_shared_between_convs.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/weight_shared_between_convs.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/weight_shared_between_convs.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/weight_shared_between_convs.bin
diff --git a/tensorflow/lite/tools/optimize/testdata/where.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/where.bin
similarity index 100%
rename from tensorflow/lite/tools/optimize/testdata/where.bin
rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/where.bin
diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD
index 711f97bdfddd16..a05a5cbdb10710 100644
--- a/tensorflow/lite/tools/optimize/BUILD
+++ b/tensorflow/lite/tools/optimize/BUILD
@@ -14,10 +14,6 @@ package(
     licenses = ["notice"],
 )
 
-exports_files(glob([
-    "testdata/*.bin",
-]))
-
 cc_library(
     name = "reduced_precision_support",
     srcs = [],
@@ -39,7 +35,6 @@ tf_cc_test(
     ],
     deps = [
         ":reduced_precision_support",
-        ":test_util",
         "//tensorflow/core/platform:platform_port",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/core:framework",
@@ -223,7 +218,6 @@ tf_cc_test(
     ],
     deps = [
         ":model_utils",
-        ":test_util",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/core:framework",
         "//tensorflow/lite/schema:schema_fbs",
@@ -250,10 +244,10 @@ tf_cc_test(
     name = "quantization_utils_test",
     srcs = ["quantization_utils_test.cc"],
     args = [
-        "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
+        "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)",
     ],
     data = [
-        ":testdata/single_conv_weights_min_0_max_plus_10.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin",
     ],
     tags = [
         "tflite_not_portable_android",
@@ -261,7 +255,7 @@ tf_cc_test(
     ],
     deps = [
         ":quantization_utils",
-        ":test_util",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:test_util",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/lite:framework",
@@ -316,13 +310,13 @@ tf_cc_test(
     name = "quantize_weights_test",
     srcs = ["quantize_weights_test.cc"],
     args = [
-        "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
+        "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)",
     ],
     data = [
-        "//tensorflow/lite/tools/optimize:testdata/custom_op.bin",
-        "//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin",
-        "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
-        "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/custom_op.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/quantized_with_gather.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/weight_shared_between_convs.bin",
     ],
     tags = [
         "tflite_not_portable_android",
@@ -330,7 +324,7 @@ tf_cc_test(
     ],
     deps = [
         ":quantize_weights",
-        ":test_util",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:test_util",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/lite:framework",
@@ -342,19 +336,6 @@ tf_cc_test(
     ],
 )
 
-cc_library(
-    name = "test_util",
-    testonly = 1,
-    srcs = ["test_util.cc"],
-    hdrs = ["test_util.h"],
-    deps = [
-        "//tensorflow/lite:framework",
-        "//tensorflow/lite/core/api",
-        "@com_google_googletest//:gtest",
-        "@flatbuffers",
-    ],
-)
-
 cc_library(
     name = "quantize_model",
     srcs = ["quantize_model.cc"],
@@ -379,40 +360,40 @@ tf_cc_test(
     name = "quantize_model_test",
     srcs = ["quantize_model_test.cc"],
     args = [
-        "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)",
+        "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)",
     ],
     data = [
-        "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin",
-        "//tensorflow/lite/tools/optimize:testdata/argmax.bin",
-        "//tensorflow/lite/tools/optimize:testdata/broadcast_to.bin",
-        "//tensorflow/lite/tools/optimize:testdata/concat.bin",
-        "//tensorflow/lite/tools/optimize:testdata/fc.bin",
-        "//tensorflow/lite/tools/optimize:testdata/fc_qat.bin",
-        "//tensorflow/lite/tools/optimize:testdata/gather_nd.bin",
-        "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin",
-        "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin",
-        "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin",
-        "//tensorflow/lite/tools/optimize:testdata/lstm_quantized2.bin",
-        "//tensorflow/lite/tools/optimize:testdata/maximum.bin",
-        "//tensorflow/lite/tools/optimize:testdata/minimum.bin",
-        "//tensorflow/lite/tools/optimize:testdata/mixed.bin",
-        "//tensorflow/lite/tools/optimize:testdata/mixed16x8.bin",
-        "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin",
-        "//tensorflow/lite/tools/optimize:testdata/pack.bin",
-        "//tensorflow/lite/tools/optimize:testdata/resource_vars_calibrated.bin",
-        "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin",
-        "//tensorflow/lite/tools/optimize:testdata/single_conv_no_bias.bin",
-        "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin",
-        "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin",
-        "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin",
-        "//tensorflow/lite/tools/optimize:testdata/split.bin",
-        "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin",
-        "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin",
-        "//tensorflow/lite/tools/optimize:testdata/transpose.bin",
-        "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_calibrated.bin",
-        "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_quantized.bin",
-        "//tensorflow/lite/tools/optimize:testdata/unpack.bin",
-        "//tensorflow/lite/tools/optimize:testdata/where.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/add_with_const_input.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/argmax.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/broadcast_to.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/concat.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc_qat.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/gather_nd.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated2.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized2.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/maximum.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/minimum.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed16x8.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/multi_input_add_reshape.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/pack.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/resource_vars_calibrated.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_avg_pool_min_minus_5_max_plus_5.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_no_bias.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_minus_127_max_plus_127.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_softmax_min_minus_5_max_plus_5.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/split.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_calibrated.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_quantized.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/transpose.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_calibrated.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_quantized.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unpack.bin",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/where.bin",
     ],
     tags = [
         "tflite_not_portable_android",
@@ -420,7 +401,7 @@ tf_cc_test(
     ],
     deps = [
         ":quantize_model",
-        ":test_util",
+        "//tensorflow/compiler/mlir/lite/quantization/lite:test_util",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
         "//tensorflow/lite:framework",
diff --git a/tensorflow/lite/tools/optimize/model_utils_test.cc b/tensorflow/lite/tools/optimize/model_utils_test.cc
index 65e3afe35e2da2..f702e1fa0a0ddd 100644
--- a/tensorflow/lite/tools/optimize/model_utils_test.cc
+++ b/tensorflow/lite/tools/optimize/model_utils_test.cc
@@ -21,7 +21,6 @@ limitations under the License.
 #include
 #include "tensorflow/lite/core/model.h"
 #include "tensorflow/lite/schema/schema_generated.h"
-#include "tensorflow/lite/tools/optimize/test_util.h"
 
 namespace tflite {
 namespace optimize {
diff --git a/tensorflow/lite/tools/optimize/quantization_utils_test.cc b/tensorflow/lite/tools/optimize/quantization_utils_test.cc
index a09acef6f4aa3c..a0ab9c43eacb75 100644
--- a/tensorflow/lite/tools/optimize/quantization_utils_test.cc
+++ b/tensorflow/lite/tools/optimize/quantization_utils_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include
 #include
 #include "absl/memory/memory.h"
+#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/init_main.h"
 #include "tensorflow/core/util/command_line_flags.h"
@@ -29,7 +30,6 @@ limitations under the License.
 #include "tensorflow/lite/schema/schema_generated.h"
 #include "tensorflow/lite/schema/schema_utils.h"
 #include "tensorflow/lite/testing/util.h"
-#include "tensorflow/lite/tools/optimize/test_util.h"
 
 namespace {
 tensorflow::string* g_test_model_dir = nullptr;
@@ -46,7 +46,7 @@ std::unique_ptr<FlatBufferModel> ReadModel(const char* model) {
 }
 
 std::unique_ptr<FlatBufferModel> ReadConvModel() {
-  return ReadModel(internal::kConvModelWith0Plus10Weights);
+  return ReadModel(mlir::lite::internal::kConvModelWith0Plus10Weights);
 }
 
 using ::testing::ElementsAreArray;
diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc
index 681507c8e0d31d..a7e9115f8bdaaa 100644
--- a/tensorflow/lite/tools/optimize/quantize_model_test.cc
+++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
 #include
 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
+#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/init_main.h"
 #include "tensorflow/core/util/command_line_flags.h"
@@ -32,7 +33,6 @@ limitations under the License.
 #include "tensorflow/lite/schema/schema_generated.h"
 #include "tensorflow/lite/schema/schema_utils.h"
 #include "tensorflow/lite/testing/util.h"
-#include "tensorflow/lite/tools/optimize/test_util.h"
 
 // Note: More rigorous model tests can be found in subgraph_quantizer_test.cc
 
@@ -78,7 +78,8 @@ TensorType GetBiasTensorType(TensorType& activation_type) {
 class QuantizeModelTest : public testing::Test {
  protected:
   QuantizeModelTest()
-      : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)) {}
+      : QuantizeModelTest(
+            ReadModel(mlir::lite::internal::kConvModelWith0Plus10Weights)) {}
 
   explicit QuantizeModelTest(std::unique_ptr<FlatBufferModel> input_model) {
     input_model_ = std::move(input_model);
@@ -91,7 +92,7 @@ class QuantizeModelTest : public testing::Test {
   const Model* readonly_model_;
   tflite::ModelT model_;
   flatbuffers::FlatBufferBuilder builder_;
-  internal::FailOnErrorReporter error_reporter_;
+  ::mlir::lite::internal::FailOnErrorReporter error_reporter_;
 };
 
 void ExpectSameModels(const ModelT& model, const ModelT& expected_model) {
@@ -136,7 +137,8 @@ class QuantizeConvModelTest : public QuantizeModelTest,
       public testing::WithParamInterface<TensorType> {
  protected:
   QuantizeConvModelTest()
-      : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)),
+      : QuantizeModelTest(
+            ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)),
         tensor_type_(GetParam()),
         bias_type_(GetBiasTensorType(tensor_type_)) {}
 
   TensorType tensor_type_;
@@ -405,7 +407,8 @@ TEST_P(QuantizeConvModelTest, Uint8InputAndOutput) {
 class QuantizeConvNoBiasModelTest : public QuantizeModelTest {
  protected:
   QuantizeConvNoBiasModelTest()
-      : QuantizeModelTest(ReadModel(internal::kConvModelWithNoBias)) {}
+      : QuantizeModelTest(
+            ReadModel(::mlir::lite::internal::kConvModelWithNoBias)) {}
 };
 
 TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) {
@@ -422,7 +425,8 @@ class QuantizeConcatModelTest : public QuantizeModelTest,
       public testing::WithParamInterface<TensorType> {
  protected:
   QuantizeConcatModelTest()
-      : QuantizeModelTest(ReadModel(internal::kFloatConcatMax5Max10Max10)) {}
+      : QuantizeModelTest(
+            ReadModel(::mlir::lite::internal::kFloatConcatMax5Max10Max10)) {}
 
   void SetUp() override {
     tensor_type_ = GetParam();
@@ -536,7 +540,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConcatModelInst, QuantizeConcatModelTest,
 class QuantizeSplitModelTest : public QuantizeModelTest {
  protected:
   QuantizeSplitModelTest()
-      : QuantizeModelTest(ReadModel(internal::kModelSplit)) {}
+      : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelSplit)) {}
 };
 
 // There are two outputs for split with different scales, the resulting model
@@ -601,8 +605,8 @@ TEST_F(QuantizeSplitModelTest, QuantizeSplit) {
 class QuantizeConvModel1Test : public QuantizeModelTest {
  protected:
   QuantizeConvModel1Test()
-      : QuantizeModelTest(
-            ReadModel(internal::kConvModelWithMinus128Plus127Weights)) {}
+      : QuantizeModelTest(ReadModel(
+            ::mlir::lite::internal::kConvModelWithMinus128Plus127Weights)) {}
 };
 
 TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) {
@@ -703,7 +707,8 @@ class QuantizeConvModel2Test : public QuantizeModelTest,
       public testing::WithParamInterface<TensorType> {
  protected:
   QuantizeConvModel2Test()
-      : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)),
+      : QuantizeModelTest(
+            ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)),
         tensor_type_(GetParam()),
         bias_type_(GetBiasTensorType(tensor_type_)) {}
 
@@ -925,8 +930,8 @@ TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) {
 class QuantizeSoftmaxTest : public QuantizeModelTest {
  protected:
   QuantizeSoftmaxTest()
-      : QuantizeModelTest(
-            ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5)) {}
+      : QuantizeModelTest(ReadModel(
+            ::mlir::lite::internal::kSingleSoftmaxModelMinMinus5MaxPlus5)) {}
 };
 
 TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) {
@@ -985,8 +990,8 @@ TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) {
 class QuantizeAvgPoolTest : public QuantizeModelTest {
  protected:
   QuantizeAvgPoolTest()
-      : QuantizeModelTest(
-            ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5)) {}
+      : QuantizeModelTest(ReadModel(
+            ::mlir::lite::internal::kSingleAvgPoolModelMinMinus5MaxPlus5)) {}
 };
 
 TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) {
@@ -1045,7 +1050,8 @@ TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) {
 class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest {
  protected:
   QuantizeMultiInputAddWithReshapeTest()
-      : QuantizeModelTest(ReadModel(internal::kMultiInputAddWithReshape)) {}
+      : QuantizeModelTest(
+            ReadModel(::mlir::lite::internal::kMultiInputAddWithReshape)) {}
 };
 
 TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) {
@@ -1155,7 +1161,8 @@ class QuantizeConstInputTest : public QuantizeModelTest,
       public testing::WithParamInterface<TensorType> {
  protected:
   QuantizeConstInputTest()
-      : QuantizeModelTest(ReadModel(internal::kConstInputAddModel)),
+      : QuantizeModelTest(
+            ReadModel(::mlir::lite::internal::kConstInputAddModel)),
         tensor_type_(GetParam()),
         bias_type_(GetBiasTensorType(tensor_type_)) {}
 
@@ -1213,7 +1220,8 @@ TEST_P(QuantizeConstInputTest, VerifyConstOpInput) {
 class QuantizeArgMaxTest : public QuantizeModelTest {
  protected:
   QuantizeArgMaxTest()
-      : QuantizeModelTest(ReadModel(internal::kModelWithArgMaxOp)) {}
+      : QuantizeModelTest(
+            ReadModel(::mlir::lite::internal::kModelWithArgMaxOp)) {}
 };
 
 TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
@@ -1254,7 +1262,7 @@ TEST_F(QuantizeArgMaxTest, VerifyArgMax) {
 class QuantizeLSTMTest : public QuantizeModelTest {
  protected:
   QuantizeLSTMTest()
-      : QuantizeModelTest(ReadModel(internal::kLstmCalibrated)) {}
+      : QuantizeModelTest(ReadModel(::mlir::lite::internal::kLstmCalibrated)) {}
 };
 
 TEST_F(QuantizeLSTMTest, VerifyLSTM) {
@@ -1265,7 +1273,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) {
   ASSERT_EQ(kTfLiteOk, status);
 
   // Read expected model.
-  auto expected_fb_model = ReadModel(internal::kLstmQuantized);
+  auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized);
   auto expected_read_only_model = expected_fb_model->GetModel();
   ModelT expected_model;
   expected_read_only_model->UnPackTo(&expected_model);
@@ -1276,7 +1284,8 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) {
 class QuantizeLSTM2Test : public QuantizeModelTest {
  protected:
   QuantizeLSTM2Test()
-      : QuantizeModelTest(ReadModel(internal::kLstmCalibrated2)) {}
+      : QuantizeModelTest(ReadModel(::mlir::lite::internal::kLstmCalibrated2)) {
+  }
 };
 
 TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
@@ -1287,7 +1296,7 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
   ASSERT_EQ(kTfLiteOk, status);
 
   // Read expected model.
-  auto expected_fb_model = ReadModel(internal::kLstmQuantized2);
+  auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized2);
   auto expected_read_only_model = expected_fb_model->GetModel();
   ModelT expected_model;
   expected_read_only_model->UnPackTo(&expected_model);
@@ -1298,8 +1307,8 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) {
 class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest {
  protected:
   QuantizeUnidirectionalSequenceLSTMTest()
-      : QuantizeModelTest(
-            ReadModel(internal::kUnidirectionalSequenceLstmCalibrated)) {}
+      : QuantizeModelTest(ReadModel(
+            ::mlir::lite::internal::kUnidirectionalSequenceLstmCalibrated)) {}
 };
 
 TEST_F(QuantizeUnidirectionalSequenceLSTMTest,
@@ -1312,7 +1321,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest,
 
   // Read expected model.
   auto expected_fb_model =
-      ReadModel(internal::kUnidirectionalSequenceLstmQuantized);
+      ReadModel(::mlir::lite::internal::kUnidirectionalSequenceLstmQuantized);
   auto expected_read_only_model = expected_fb_model->GetModel();
   ModelT expected_model;
   expected_read_only_model->UnPackTo(&expected_model);
@@ -1323,7 +1332,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest,
 class QuantizeSVDFTest : public QuantizeModelTest {
  protected:
   QuantizeSVDFTest()
-      : QuantizeModelTest(ReadModel(internal::kSvdfCalibrated)) {}
+      : QuantizeModelTest(ReadModel(::mlir::lite::internal::kSvdfCalibrated)) {}
 };
 
 TEST_F(QuantizeSVDFTest, VerifySVDF) {
@@ -1334,7 +1343,7 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) {
   ASSERT_EQ(kTfLiteOk, status);
 
   // Read expected model.
-  auto expected_fb_model = ReadModel(internal::kSvdfQuantized);
+  auto expected_fb_model = ReadModel(::mlir::lite::internal::kSvdfQuantized);
   auto expected_read_only_model = expected_fb_model->GetModel();
   ModelT expected_model;
   expected_read_only_model->UnPackTo(&expected_model);
@@ -1379,7 +1388,8 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) {
 class QuantizeFCTest : public QuantizeModelTest {
  protected:
-  QuantizeFCTest() : QuantizeModelTest(ReadModel(internal::kModelWithFCOp)) {}
+  QuantizeFCTest()
+      : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelWithFCOp)) {}
 };
 
 TEST_F(QuantizeFCTest, VerifyFC) {
@@ -1430,7 +1440,7 @@ class QuantizeCustomOpTest
     public ::testing::WithParamInterface<TensorType> {
  protected:
   QuantizeCustomOpTest()
-      : QuantizeModelTest(ReadModel(internal::kModelMixed)),
+      : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelMixed)),
         tensor_type_(GetParam()),
         bias_type_(GetBiasTensorType(tensor_type_)) {}
 
@@ -1471,7 +1481,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest,
 class QuantizeOp16x8Test : public QuantizeModelTest {
  protected:
   QuantizeOp16x8Test()
-      : QuantizeModelTest(ReadModel(internal::kModelMixed16x8)) {}
+      : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelMixed16x8)) {}
 };
 
 TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) {
@@ -1502,7 +1512,8 @@ TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) {
 class QuantizePackTest : public QuantizeModelTest {
  protected:
-  QuantizePackTest() : QuantizeModelTest(ReadModel(internal::kModelPack)) {}
+  QuantizePackTest()
+      : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelPack)) {}
 };
 
 TEST_F(QuantizePackTest, VerifyPack) {
@@ -1628,14 +1639,16 @@ TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) {
   EXPECT_EQ(subgraph->tensors[5]->name, "output");
 }
 
-INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest,
-                         testing::ValuesIn({internal::kModelWithMinimumOp,
-                                            internal::kModelWithMaximumOp}));
+INSTANTIATE_TEST_SUITE_P(
+    MinimumMaximumTestInst, QuantizeMinimumMaximumTest,
+    testing::ValuesIn({::mlir::lite::internal::kModelWithMinimumOp,
+                       ::mlir::lite::internal::kModelWithMaximumOp}));
 
 class QuantizeUnpackTest : public QuantizeModelTest {
  protected:
   QuantizeUnpackTest()
-      : QuantizeModelTest(ReadModel(internal::kModelWithUnpack)) {}
+      : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelWithUnpack)) {
+  }
 };
 
 TEST_F(QuantizeUnpackTest, VerifyUnpack) {
   auto status = QuantizeModel(&builder_, &model_, &error_reporter_);
@@ -1680,7 +1693,8 @@ TEST_F(QuantizeUnpackTest, VerifyUnpack) {
 class QuantizeTransposeTest : public QuantizeModelTest {
  protected:
   QuantizeTransposeTest()
-      : QuantizeModelTest(ReadModel(internal::kModelWithTranspose)) {}
+      : QuantizeModelTest(
+            ReadModel(::mlir::lite::internal::kModelWithTranspose)) {}
 };
 
 TEST_F(QuantizeTransposeTest, VerifyTranspose) {
@@ -1720,7 +1734,8 @@ TEST_F(QuantizeTransposeTest, VerifyTranspose) {
 class QuantizeQatTest : public QuantizeModelTest {
  protected:
-  QuantizeQatTest() : QuantizeModelTest(ReadModel(internal::kQatModelWithFc)) {}
+  QuantizeQatTest()
+      : QuantizeModelTest(ReadModel(::mlir::lite::internal::kQatModelWithFc)) {}
 };
 
 TEST_F(QuantizeQatTest, VerifySingleQuantize) {
@@ -1777,7 +1792,8 @@ class QuantizeBroadcastToModelTest
     public testing::WithParamInterface<TensorType> {
  protected:
   QuantizeBroadcastToModelTest()
-      : QuantizeModelTest(ReadModel(internal::kModelWithBroadcastToOp)),
+      : QuantizeModelTest(
+            ReadModel(::mlir::lite::internal::kModelWithBroadcastToOp)),
         tensor_type_(GetParam()),
         bias_type_(GetBiasTensorType(tensor_type_)) {}
 
@@ -1844,7 +1860,8 @@ class QuantizeGatherNDModelTest
     public testing::WithParamInterface<TensorType> {
  protected:
   QuantizeGatherNDModelTest()
-      : QuantizeModelTest(ReadModel(internal::kModelWithGatherNDOp)),
+      : QuantizeModelTest(
+            ReadModel(::mlir::lite::internal::kModelWithGatherNDOp)),
         tensor_type_(GetParam()),
         bias_type_(GetBiasTensorType(tensor_type_)) {}
 
@@ -1906,7 +1923,8 @@ TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) {
 class QuantizeWhereModelTest : public QuantizeModelTest {
  protected:
   QuantizeWhereModelTest()
-      : QuantizeModelTest(ReadModel(internal::kModelWithWhereOp)) {}
+      : QuantizeModelTest(
+            ReadModel(::mlir::lite::internal::kModelWithWhereOp)) {}
 };
 
 TEST_F(QuantizeWhereModelTest, QuantizeWhere) {
@@ -1976,8 +1994,8 @@ class QuantizeResourcesModelTest
     public testing::WithParamInterface<TestType> {
  protected:
   QuantizeResourcesModelTest()
-      : QuantizeModelTest(
-            ReadModel(internal::kModelWithResourceVarsCalibrated)) {
+      : QuantizeModelTest(ReadModel(
+            ::mlir::lite::internal::kModelWithResourceVarsCalibrated)) {
     TestType obj = GetParam();
     tensor_type_ = obj.tensor_type;
     modify_range_ = obj.modify_range;
@@ -2119,7 +2137,8 @@ class QuantizeConcatConstModelTest
     public testing::WithParamInterface<TensorType> {
  protected:
   QuantizeConcatConstModelTest()
-      : QuantizeModelTest(ReadModel(internal::kFloatConcatMax5Max10Max10)) {
+      : QuantizeModelTest(
+            ReadModel(::mlir::lite::internal::kFloatConcatMax5Max10Max10)) {
     // Make one of the values constant.
     MakeInputConstant(&model_);
   }
@@ -2224,7 +2243,8 @@ class BiasInputTest : public QuantizeModelTest,
       public testing::WithParamInterface<BiasTestType> {
  protected:
   BiasInputTest()
-      : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)) {
+      : QuantizeModelTest(
+            ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)) {
     BiasTestType obj = GetParam();
     tensor_type_ = obj.tensor_type;
     bias_type_ = obj.bias_type;
diff --git a/tensorflow/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/lite/tools/optimize/quantize_weights_test.cc
index 0e9c3efc17acd9..b2279ed34908f6 100644
--- a/tensorflow/lite/tools/optimize/quantize_weights_test.cc
+++ b/tensorflow/lite/tools/optimize/quantize_weights_test.cc
@@ -22,13 +22,13 @@ limitations under the License.
 #include
 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
+#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/init_main.h"
 #include "tensorflow/core/util/command_line_flags.h"
 #include "tensorflow/lite/core/model.h"
 #include "tensorflow/lite/schema/schema_generated.h"
 #include "tensorflow/lite/schema/schema_utils.h"
-#include "tensorflow/lite/tools/optimize/test_util.h"
 
 namespace {
 tensorflow::string* g_test_model_dir = nullptr;
@@ -40,25 +40,25 @@ namespace {
 
 std::unique_ptr<FlatBufferModel> ReadTestModel() {
   auto model_path = tensorflow::io::JoinPath(
-      *g_test_model_dir, internal::kConvModelWith0Plus10Weights);
+      *g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights);
   return FlatBufferModel::BuildFromFile(model_path.c_str());
 }
 
 std::unique_ptr<FlatBufferModel> ReadSharedWeightsTestModel() {
-  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir,
-                                             internal::kModelWithSharedWeights);
+  auto model_path = tensorflow::io::JoinPath(
+      *g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights);
   return FlatBufferModel::BuildFromFile(model_path.c_str());
 }
 
 std::unique_ptr<FlatBufferModel> ReadGatherTestModel() {
-  auto model_path = tensorflow::io::JoinPath(*g_test_model_dir,
-                                             internal::kQuantizedWithGather);
+  auto model_path = tensorflow::io::JoinPath(
+      *g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather);
   return FlatBufferModel::BuildFromFile(model_path.c_str());
 }
 
 std::unique_ptr<FlatBufferModel> ReadCustomOpTestModel() {
-  auto model_path =
-      tensorflow::io::JoinPath(*g_test_model_dir, internal::kModelWithCustomOp);
+  auto model_path = tensorflow::io::JoinPath(
+      *g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp);
   return FlatBufferModel::BuildFromFile(model_path.c_str());
 }
diff --git a/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc b/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc
index 6b5cf538b50c43..19400079b17e96 100644
--- a/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc
+++ b/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace tflite { namespace optimize { diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 553d29e960458e..a4fc2041629600 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -3057,11 +3057,14 @@ cc_library( srcs = ["all_reduce_simplifier.cc"], hdrs = ["all_reduce_simplifier.h"], deps = [ + ":collective_ops_utils", + ":hlo_module_config", ":hlo_pass", ":hlo_replication_analysis", "//xla:literal_util", "//xla:shape_util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", @@ -3076,6 +3079,7 @@ xla_cc_test( srcs = ["all_reduce_simplifier_test.cc"], deps = [ ":all_reduce_simplifier", + ":hlo_module_config", ":hlo_parser", ":pattern_matcher", ":pattern_matcher_gmock", @@ -7254,6 +7258,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index dd13ffebd9e852..641fedf0c72405 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -8188,6 +8188,33 @@ absl::Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) { select->mutable_operand(0)->shape(), HloOpcode::kNot, select->mutable_operand(0))); } + // select(compare(a, b, GT/GE), a, b) => or(a, b) + // select(compare(a, b, LT/LE), a, b) => and(a, b) + // select(compare(a, b, EQ), a, b) => b + // select(compare(a, b, NE), a, b) => a + HloInstruction *compare, *lhs, *rhs; + if (Match(select, m::Select(m::Op(&compare), m::Op(&lhs), m::Op(&rhs))) && + Match(compare, m::Compare(m::Op().Is(lhs), m::Op().Is(rhs)))) { + auto cmp_dir = compare->comparison_direction(); + if (cmp_dir == ComparisonDirection::kGt || + cmp_dir == ComparisonDirection::kGe) { + return ReplaceWithNewInstruction( + select, HloInstruction::CreateBinary(select->shape(), + HloOpcode::kOr, lhs, rhs)); + } + if (cmp_dir == ComparisonDirection::kLt || + cmp_dir == ComparisonDirection::kLe) { + return ReplaceWithNewInstruction( + select, HloInstruction::CreateBinary(select->shape(), + HloOpcode::kAnd, lhs, rhs)); + } + if (cmp_dir == ComparisonDirection::kEq) { + return ReplaceInstruction(select, rhs); + } + if (cmp_dir == ComparisonDirection::kNe) { + return ReplaceInstruction(select, lhs); + } + } } // select(pred, xs, dynamic_update_slice(xs, x, i)) diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index 00970a51546b1a..921098aa7565e8 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -736,6 +736,95 @@ TEST_F(AlgebraicSimplifierTest, SelectPredPred2) { GmockMatch(m::Not(m::Parameter(0)))); } +// select(compare(a, b, GT/GE), a, b) => or(a, b), a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectGtCompare) { + for (const auto cmp_dir : {"GT", "GE"}) { + const auto kModuleStr 
= absl::StrFormat(R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=%s + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )", + cmp_dir); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Or(m::Parameter(0), m::Parameter(1)))); + } +} + +// select(compare(a, b, LT/LE), a, b) => and(a, b), a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectLtCompare) { + for (const auto cmp_dir : {"LT", "LE"}) { + const auto kModuleStr = absl::StrFormat(R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=%s + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )", + cmp_dir); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::And(m::Parameter(0), m::Parameter(1)))); + } +} + +// select(compare(a, b, EQ), a, b) => b, a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectEqCompare) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=EQ + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(1))); +} + +// select(compare(a, b, NE), a, b) => a, a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectNeCompare) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=NE + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(0))); +} + +// select(compare(a, b, NE), b, a) ≠> a - wrong operands order +TEST_F(AlgebraicSimplifierTest, SelectNeCompare_NegativeTestCase) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=NE + ROOT select = pred[8]{0} select(compare, p1, p0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + // Test that select(pred, xs, dynamic_update_slice(xs, x, i)) is simplified // to dynamic_update_slice(xs, select(pred, dynamic_slice(xs, i), x), i) TEST_F(AlgebraicSimplifierTest, SelectDUSWithShapedPred) { diff --git a/third_party/xla/xla/service/all_reduce_simplifier.cc b/third_party/xla/xla/service/all_reduce_simplifier.cc index 5837ea49da0aae..0760433bda4489 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier.cc @@ -19,14 +19,18 @@ limitations under the License. 
#include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal_util.h" +#include "xla/service/collective_ops_utils.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_replication_analysis.h" #include "xla/shape_util.h" #include "tsl/platform/errors.h" @@ -42,22 +46,33 @@ absl::StatusOr<bool> AllReduceSimplifier::Run( HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/false)); std::vector<std::pair<HloInstruction*, int64_t>> all_reduces_to_replace; - // Returns the size of a replica group if all groups have the same size, or -1 - // if they have different sizes. - auto get_replica_group_size = - [this](const HloInstruction* all_reduce) -> int64_t { - if (all_reduce->replica_groups().empty()) { - return replica_count_; + // Returns the number of participants in a replica group if all groups have + // the same size, or -1 if they have different sizes. + // The number of participants depends on the mode of the collective + // operation. + auto get_participant_counts_for_replica_group = + [](const HloInstruction* all_reduce) -> absl::StatusOr<int64_t> { + const HloModuleConfig& config = all_reduce->GetModule()->config(); + TF_ASSIGN_OR_RETURN( + CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode(all_reduce->channel_id().has_value(), + Cast<HloAllReduceInstruction>(all_reduce) + ->use_global_device_ids())); + + int64_t num_devices = config.num_partitions(); + int64_t num_replicas = config.replica_count(); + TF_ASSIGN_OR_RETURN(std::vector<int64_t> participant_counts, + GetPariticipantCountsForReplicaGroups( + num_replicas, num_devices, + all_reduce->replica_groups(), group_mode)); + if (participant_counts.empty()) { + return -1; } - int64_t replica_group_size = -1; - for (const auto& group : all_reduce->replica_groups()) { - if (replica_group_size == -1) { - replica_group_size = group.replica_ids_size(); - } else if (replica_group_size != group.replica_ids_size()) { - return -1; - } + if (!absl::c_all_of(participant_counts, [&](int64_t participant_count) { + return participant_count == participant_counts[0]; + })) { + return -1; } - return replica_group_size; + return participant_counts[0]; }; bool changed = false; @@ -83,11 +98,24 @@ absl::StatusOr<bool> AllReduceSimplifier::Run( // optimize out (being fed within a tuple input). continue; } - if (!inst->IsCrossReplicaAllReduce()) { + if (!inst->IsCrossReplicaAllReduce() && !inst->IsCrossModuleAllReduce()) { continue; } - int64_t group_size = get_replica_group_size(inst); - if (group_size == -1) { + TF_ASSIGN_OR_RETURN(int64_t group_size, + get_participant_counts_for_replica_group(inst)); + + // We will not simplify this all-reduce if any of the following is true: + // 1. Not all replica groups have the same size. + // + // 2. The all-reduce is not cross-replica and the group size is not 1, + // since the replication analysis performed earlier only covers the + // cross-replica SPMD case. + // + // 3. The all-reduce is not cross-replica and the module is not using SPMD + // partitioning.
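 + // + // For example (taken from the tests below): with num_replicas=1 and + // num_partitions=8, an all-reduce with a channel_id and + // use_global_device_ids=true over replica_groups={{0},{1},...,{7}} has a + // participant count of 1 per group, so it is a no-op and can be replaced + // by its operand, while replica_groups={{0,2},{1,3},{4,6},{5,7}} gives a + // participant count of 2 and the op must be kept.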
+ if (group_size == -1 || + (!inst->IsCrossReplicaAllReduce() && group_size != 1) || + (!inst->IsCrossReplicaAllReduce() && + !module->config().use_spmd_partitioning())) { continue; } if (replication->HloInstructionIsReplicatedAt(inst->operand(0), {}) || diff --git a/third_party/xla/xla/service/all_reduce_simplifier_test.cc b/third_party/xla/xla/service/all_reduce_simplifier_test.cc index 0843fc6df1a87b..e78881a0c19292 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier_test.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" @@ -191,5 +192,95 @@ test { EXPECT_THAT(module->entry_computation()->root_instruction(), GmockMatch(m::Parameter(0))); } + +TEST_F(AllReduceSimplifierTest, TrivialSubgroupNonCrossReplicaAllReduce) { + const char* kModuleStr = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + + +test { + p0 = f32[8,16] parameter(0), parameter_replication={false} + ROOT all-reduce = f32[8,16] all-reduce(p0), + channel_id=1, + use_global_device_ids=true, + replica_groups={{0},{1},{2},{3},{4},{5},{6},{7}}, + to_apply=sum +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleStr, /*replica_count=*/1, + /*num_partitions=*/8)); + module->mutable_config().set_use_spmd_partitioning(true); + AllReduceSimplifier simplifier(/*replica_count=*/1); + EXPECT_TRUE(simplifier.Run(module.get()).value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(0))); +} + +TEST_F(AllReduceSimplifierTest, NonCrossReplicaAllReduceAfterAllReduce) { + const char* kModuleStr = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + + +test { + p0 = f32[8,16] parameter(0), parameter_replication={false} + all-reduce = f32[8,16] all-reduce(p0), + channel_id=1, + use_global_device_ids=true, + replica_groups={{0,2},{1,3},{4,6},{5,7}}, + to_apply=sum + ROOT all-reduce.1 = f32[8,16] all-reduce(all-reduce), + channel_id=2, + use_global_device_ids=true, + replica_groups={{0,4},{1,5},{2,6},{3,7}}, + to_apply=sum +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleStr, /*replica_count=*/1, + /*num_partitions=*/8)); + module->mutable_config().set_use_spmd_partitioning(true); + AllReduceSimplifier simplifier(/*replica_count=*/1); + EXPECT_FALSE(simplifier.Run(module.get()).value()); +} + +TEST_F(AllReduceSimplifierTest, MPMDNonCrossReplicaAllReduce) { + const char* kModuleStr = R"( +HloModule m + +sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) +} + +test { + p0 = f32[8,16] parameter(0), parameter_replication={false} + ROOT all-reduce = f32[8,16] all-reduce(p0), + channel_id=1, + replica_groups={{0},{1}}, + to_apply=sum +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleStr, /*replica_count=*/2, + /*num_partitions=*/1)); + // Mark as MPMD. 
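 + // With use_spmd_partitioning=false the pass hits condition 3 in Run(): + // a cross-module all-reduce is only simplified under SPMD partitioning, + // so this single-participant all-reduce is kept and Run() reports no + // change.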
+ module->mutable_config().set_use_spmd_partitioning(false); + AllReduceSimplifier simplifier(/*replica_count=*/2); + EXPECT_FALSE(simplifier.Run(module.get()).value()); +} } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/collective_ops_utils.cc b/third_party/xla/xla/service/collective_ops_utils.cc index 01c4bba5abfefb..5bd802c343f523 100644 --- a/third_party/xla/xla/service/collective_ops_utils.cc +++ b/third_party/xla/xla/service/collective_ops_utils.cc @@ -515,8 +515,12 @@ absl::StatusOr<std::vector<int64_t>> GetPariticipantCountsForReplicaGroups( switch (group_mode) { case CollectiveOpGroupMode::kCrossReplica: { - participant_counts.resize(participating_replica_groups.size(), - num_partitions); + for (const auto& replica_group : participating_replica_groups) { + for (int partition_id = 0; partition_id < num_partitions; + ++partition_id) { + participant_counts.push_back(replica_group.replica_ids().size()); + } + } return participant_counts; } case CollectiveOpGroupMode::kCrossPartition: { diff --git a/third_party/xla/xla/service/collective_ops_utils_test.cc b/third_party/xla/xla/service/collective_ops_utils_test.cc index 9fbcaf5f1f2ba2..c71776323f869f 100644 --- a/third_party/xla/xla/service/collective_ops_utils_test.cc +++ b/third_party/xla/xla/service/collective_ops_utils_test.cc @@ -22,6 +22,8 @@ limitations under the License. #include #include +#include +#include #include "absl/algorithm/container.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" @@ -129,6 +131,21 @@ TEST(CollectiveOpsUtilsTest, CollectiveWithChannelId2) { EXPECT_EQ(IsOrHasCollectiveWithChannelId(fusion2.get()), nullptr); } +// Creates a container of ReplicaGroups. +std::vector<ReplicaGroup> CreateReplicaGroups( + const std::vector<std::vector<int64_t>> &replica_groups) { + std::vector<ReplicaGroup> result; + result.reserve(replica_groups.size()); + for (const auto &replica_group : replica_groups) { + ReplicaGroup group; + for (auto id : replica_group) { + group.add_replica_ids(id); + } + result.push_back(group); + } + return result; +} + } // namespace // Tests for GetCollectiveOpGroupMode @@ -190,7 +207,7 @@ namespace GetParticipatingDevicesTest { // expected output corresponding to those values. 
struct TestCase { xla::Array2D<int> device_assignment; - std::vector<std::vector<int>> replica_groups; + std::vector<std::vector<int64_t>> replica_groups; bool has_channel_id; std::optional<bool> use_global_device_ids; @@ -455,15 +472,8 @@ TEST_P(GetParticipatingDevicesTest, Test) { } } - std::vector<ReplicaGroup> replica_groups; - absl::c_transform(tc.replica_groups, std::back_inserter(replica_groups), - [](const std::vector<int> &ids) { - ReplicaGroup group; - for (int id : ids) { - group.add_replica_ids(id); - } - return group; - }); + std::vector<ReplicaGroup> replica_groups = + CreateReplicaGroups(tc.replica_groups); absl::StatusOr<CollectiveOpGroupMode> group_mode = GetCollectiveOpGroupMode(tc.has_channel_id, tc.use_global_device_ids); @@ -518,4 +528,77 @@ INSTANTIATE_TEST_SUITE_P(GetParticipatingDevices, GetParticipatingDevicesTest, testing::ValuesIn(GetTestCases())); } // namespace GetParticipatingDevicesTest + +namespace GetPariticipantCountsForReplicaGroupsTest { + +struct TestCase { + std::string test_name; + std::vector<std::vector<int64_t>> replica_groups; + CollectiveOpGroupMode group_mode; + int64_t num_replicas; + int64_t num_partitions; + std::vector<int64_t> expected; +}; + +class GetPariticipantCountsForReplicaGroupsTest + : public testing::TestWithParam<TestCase> {}; + +TEST_P(GetPariticipantCountsForReplicaGroupsTest, Test) { + const TestCase &tc = GetParam(); + + std::vector<ReplicaGroup> replica_groups = + CreateReplicaGroups(tc.replica_groups); + TF_ASSERT_OK_AND_ASSIGN( + std::vector<int64_t> actual, + GetPariticipantCountsForReplicaGroups(tc.num_replicas, tc.num_partitions, + replica_groups, tc.group_mode)); + EXPECT_THAT(actual, testing::ElementsAreArray(tc.expected)); +} + +std::vector<TestCase> GetTestCases() { + return { + { + "CrossReplicaEmptyGroup", + {}, + CollectiveOpGroupMode::kCrossReplica, + 8, + 1, + {8}, + }, + { + "CrossReplicaWithPartitions", + {{0, 1}, {2, 3}}, + CollectiveOpGroupMode::kCrossReplica, + 4, + 2, + {2, 2, 2, 2}, + }, + { + "CrossReplicaAndPartition", + {{0, 1}, {2, 3}}, + CollectiveOpGroupMode::kCrossReplicaAndPartition, + 4, + 2, + {4, 4}, + }, + { + "FlattenedID", + {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}}, + CollectiveOpGroupMode::kFlattenedID, + 4, + 2, + {1, 1, 1, 1, 1, 1, 1, 1}, + }, + }; +} +INSTANTIATE_TEST_SUITE_P( + GetPariticipantCountsForReplicaGroups, + GetPariticipantCountsForReplicaGroupsTest, + testing::ValuesIn(GetTestCases()), + [](const testing::TestParamInfo< + GetPariticipantCountsForReplicaGroupsTest::ParamType> &info) { + return info.param.test_name; + }); + +} // namespace GetPariticipantCountsForReplicaGroupsTest } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc index ce1dfa4f1c7863..8f5e26124f24a4 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc +++ b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc @@ -378,8 +378,16 @@ absl::StatusOr<bool> HandleAgWindowedEinsumLoop(HloComputation* comp, return changed; } +static int64_t GetAgActivationCacheIndex(const HloInstruction* while_loop) { + const HloInstruction* loop_tuple = while_loop->operand(0); + const Shape& tuple_shape = loop_tuple->shape(); + CHECK(tuple_shape.IsTuple()); + return tuple_shape.tuple_shapes_size(); +} + absl::Status ProcessWindowedEinsumLoopForActivationCaching( - GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop) { + GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop, + HloInstruction* ag_with_shared_operand) { HloInstruction* loop = ag_loop.loop; // Transform the while body to cache the allgathered result in the // output buffer to be consumed by the dot @@ 
-392,15 +400,61 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( } // Get the output operand of the full buffer. HloInstruction* root = while_body->root_instruction(); + // Change loop body to include the new input and output element. + HloInstruction* input_tuple = while_body->parameter_instruction(0); + const Shape& input_shape = input_tuple->shape(); // The full buffer that we will use to cache the accumulated activation - // is the 4th operand in the output tuple. - int64_t full_cache_buffer_index = 3; + // is the last operand in the output tuple. + int64_t full_cache_buffer_index = GetAgActivationCacheIndex(loop); + std::vector<Shape> new_input_shapes(input_shape.tuple_shapes().begin(), + input_shape.tuple_shapes().end()); + new_input_shapes.push_back(ag_with_shared_operand->shape()); + // Update the body input shape. + Shape new_input_shape = ShapeUtil::MakeTupleShape(new_input_shapes); + *input_tuple->mutable_shape() = new_input_shape; HloInstruction* full_buffer_output_gte = - root->mutable_operand(full_cache_buffer_index); - HloInstruction* new_full_buffer_output; + while_body->AddInstruction(HloInstruction::CreateGetTupleElement( + ag_with_shared_operand->shape(), input_tuple, + full_cache_buffer_index)); + + // Update the condition input shape. + HloComputation* cond_comp = loop->while_condition(); + HloInstruction* cond_input_tuple = cond_comp->parameter_instruction(0); + *cond_input_tuple->mutable_shape() = new_input_shape; + + // Update the input to the while instruction in the parent computation. + HloInstruction* original_while_input = loop->mutable_operand(0); + HloComputation* parent_comp = loop->parent(); + std::vector<HloInstruction*> new_operands( + original_while_input->operands().begin(), + original_while_input->operands().end()); + new_operands.push_back( + parent_comp->AddInstruction(HloInstruction::CreateBroadcast( + ag_with_shared_operand->shape(), + parent_comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(new_input_shapes[0].element_type()))), + {}))); + HloInstruction* new_while_input = + parent_comp->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR( + loop->ReplaceOperandWithDifferentShape(0, new_while_input)); + TF_RETURN_IF_ERROR(parent_comp->ReplaceInstructionWithDifferentShape( + original_while_input, new_while_input)); + *loop->mutable_shape() = new_input_shape; + + HloInstruction* new_full_buffer_output = nullptr; // Find the DUS in the loop body and re-use the slice indices // This should just be a constant(0) HloInstruction* dus_boundary_constant; + // The slice we need this time is the output of the first + // collective-permute + HloInstruction* first_cp_output; + for (HloInstruction* gte_user : input_gte->users()) { + if (gte_user->opcode() == HloOpcode::kCollectivePermute) { + first_cp_output = gte_user; + break; + } + } for (HloInstruction* inst : while_body->MakeInstructionPostOrder()) { HloInstruction* slice_indices; // If we have a DUS(PARAM,DS) pattern, we need to update the output @@ -434,24 +488,68 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( dus_boundary_constant->shape(), slice_indices)); VLOG(5) << "Created slice op for second slice: " << slice_indices->ToString(); - // The slice we need this time is the output of the first - // collective-permute - HloInstruction* cp_output; - for (HloInstruction* gte_user : input_gte->users()) { - if (gte_user->opcode() == HloOpcode::kCollectivePermute) { - cp_output = gte_user; - break; - } - } new_full_buffer_output = 
while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( full_buffer_output_gte->shape(), full_buffer_output_gte, - cp_output, + first_cp_output, {dus_boundary_constant, slice_indices, dus_boundary_constant})); } + + // If we have a Dot(DS(parameter_index1)), then operands are sharded along + // the contracting dim. Slice indices will be the contracting dim's slices. + HloInstruction* slice_index; + HloInstruction* ds_index_constant; + HloInstruction* remainder; + HloInstruction* ds_param; + // Unrolled loops have two dynamic-slices; match each one to recover the + // slice index used to write the corresponding received shard into the + // cached activation buffer. The final buffer is written twice per + // iteration, so the correct slice index has to be matched for each DS. + if (Match(inst, m::Dot(m::Op(), m::DynamicSlice(&ds_param))) && + Match(ds_param->operand(0), m::GetTupleElement(m::Parameter(), 1))) { + for (int64_t ds_op_i = 1; ds_op_i < ds_param->operands().size(); + ds_op_i++) { + if (!Match( + ds_param->mutable_operand(ds_op_i), + m::Reshape(&slice_index, m::DynamicSlice(m::Constant(), + m::Op(&remainder)))) && + !Match(ds_param->mutable_operand(ds_op_i), + m::Constant(&ds_index_constant))) { + return absl::OkStatus(); + } + } + // The first DS has its slice index calculated from the loop iterator: + // Remainder(add(gte, partition_id)). + if (Match(remainder, + m::Remainder(m::Add(m::GetTupleElement(), m::Op()), m::Op()))) { + full_buffer_output_gte = + while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + full_buffer_output_gte->shape(), full_buffer_output_gte, + input_gte, + {ds_index_constant, ds_index_constant, slice_index})); + } + // The second DS has its slice index calculated from the loop iterator + // plus one, hence Remainder(add(add(gte, 1), partition_id)). + if (Match(remainder, + m::Remainder( + m::Add(m::Add(m::GetTupleElement(), m::Op()), m::Op()), + m::Op()))) { + new_full_buffer_output = + while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + full_buffer_output_gte->shape(), full_buffer_output_gte, + first_cp_output, + {ds_index_constant, ds_index_constant, slice_index})); + } + } } - TF_RETURN_IF_ERROR(root->ReplaceOperandWith(full_cache_buffer_index, - new_full_buffer_output)); + std::vector<HloInstruction*> original_operands(root->operands().begin(), + root->operands().end()); + original_operands.push_back(new_full_buffer_output); + HloInstruction* new_output_tuple = while_body->AddInstruction( + HloInstruction::CreateTuple(original_operands)); + TF_RETURN_IF_ERROR( + while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple)); return absl::OkStatus(); } @@ -620,17 +718,20 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { VLOG(5) << "Found all-gather that shares the same operand with a " "windowed einsum loop : " << loop->ToString(); + + if (!ag_loop.consumed) { + TF_RETURN_IF_ERROR(ProcessWindowedEinsumLoopForActivationCaching( + ag_loop, ag_with_shared_operand)); + ag_loop.consumed = true; + } int64_t cache_output_index = dot->operand_index(ag_with_shared_operand); - HloInstruction* new_gte = comp->AddInstruction( - HloInstruction::CreateGetTupleElement(loop, 3)); + HloComputation* comp = dot->parent(); + HloInstruction* new_gte = + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + loop, GetAgActivationCacheIndex(loop) - 1)); TF_RETURN_IF_ERROR( dot->ReplaceOperandWith(cache_output_index, new_gte));
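 + // The GTE above reads the cached activation: GetAgActivationCacheIndex + // returns the size of the loop's operand tuple after + // ProcessWindowedEinsumLoopForActivationCaching has appended the cache + // buffer, so the cache sits at the trailing index, size - 1.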
TF_RETURN_IF_ERROR(comp->RemoveInstruction(ag_with_shared_operand)); - if (!ag_loop.consumed) { - TF_RETURN_IF_ERROR( - ProcessWindowedEinsumLoopForActivationCaching(ag_loop)); - ag_loop.consumed = true; - } } } // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc index 23257e1c71a34b..6f23319980e90c 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc @@ -269,23 +269,22 @@ ENTRY main.12_spmd { FindInstructionByName(module->entry_computation(), "dot.7"); // dot.7 should now consume output of the windowed einsum while loop. EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement); - EXPECT_EQ(inst->operand(0)->tuple_index(), 3); + EXPECT_EQ(inst->operand(0)->tuple_index(), 5); EXPECT_EQ(inst->operand(0)->operand(0), ag_loop); // while loop's root should now have a chain of DUS. HloInstruction* ag_while_root = ag_loop->while_body()->root_instruction(); EXPECT_THAT(ag_while_root, GmockMatch(m::Tuple( - m::Op(), m::Op(), m::Op(), + m::Op(), m::Op(), m::Op(), m::Op(), m::Op(), m::DynamicUpdateSlice( m::DynamicUpdateSlice( m::GetTupleElement(m::Parameter()) .WithPredicate([](const HloInstruction* instr) { - return instr->tuple_index() == 3; + return instr->tuple_index() == 5; }), m::Op(), m::Op(), m::Op(), m::Op()), - m::Op(), m::Op(), m::Op(), m::Op()), - m::Op()))); + m::Op(), m::Op(), m::Op(), m::Op())))); } TEST_F(GpuWindowedEinsumHanlderTest, A2aGemmHaveStreamIds) { constexpr absl::string_view kHloString = R"( @@ -838,5 +837,82 @@ ENTRY main.9_spmd { )"); } +TEST_F(GpuWindowedEinsumHanlderTest, + AgLoopsMultipleConsumersAreChainedWithShardedContratingDim) { + constexpr absl::string_view kHloString = R"( +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0})->bf16[4096,6288]{1,0}}, num_partitions=8 + +windowed_dot_general_body_ag { + param.195 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0) + get-tuple-element.588 = bf16[16,2048,512]{2,1,0} get-tuple-element(param.195), index=0 + collective-permute.194 = bf16[16,2048,512]{2,1,0} collective-permute(get-tuple-element.588), channel_id=446, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}} + collective-permute.195 = bf16[16,2048,512]{2,1,0} collective-permute(collective-permute.194), channel_id=447, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}} + get-tuple-element.589 = bf16[4096,6288]{1,0} get-tuple-element(param.195), index=1 + get-tuple-element.590 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=2 + constant.11432 = s32[8]{0} constant({0, 512, 1024, 1536, 2048, 2560, 3072, 3584}) + get-tuple-element.592 = u32[] get-tuple-element(param.195), index=4 + partition-id.194 = u32[] partition-id() + add.4309 = u32[] add(get-tuple-element.592, partition-id.194) + constant.11431 = u32[] constant(8) + remainder.194 = u32[] remainder(add.4309, constant.11431) + dynamic-slice.388 = s32[1]{0} dynamic-slice(constant.11432, remainder.194), dynamic_slice_sizes={1} + reshape.12959 = s32[] reshape(dynamic-slice.388) + constant.11433 = s32[] constant(0) + dynamic-slice.389 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12959, 
constant.11433), dynamic_slice_sizes={512,6288} + dot.244 = bf16[16,2048,6288]{2,1,0} dot(get-tuple-element.588, dynamic-slice.389), lhs_contracting_dims={2}, rhs_contracting_dims={0} + add.4310 = bf16[16,2048,6288]{2,1,0} add(get-tuple-element.590, dot.244) + constant.11434 = u32[] constant(1) + add.4312 = u32[] add(get-tuple-element.592, constant.11434) + add.4313 = u32[] add(add.4312, partition-id.194) + remainder.195 = u32[] remainder(add.4313, constant.11431) + dynamic-slice.390 = s32[1]{0} dynamic-slice(constant.11432, remainder.195), dynamic_slice_sizes={1} + reshape.12960 = s32[] reshape(dynamic-slice.390) + dynamic-slice.391 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12960, constant.11433), dynamic_slice_sizes={512,6288} + dot.245 = bf16[16,2048,6288]{2,1,0} dot(collective-permute.194, dynamic-slice.391), lhs_contracting_dims={2}, rhs_contracting_dims={0} + add.4314 = bf16[16,2048,6288]{2,1,0} add(add.4310, dot.245) + get-tuple-element.591 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=3 + add.4315 = u32[] add(add.4312, constant.11434) + ROOT tuple.98 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(collective-permute.195, get-tuple-element.589, add.4314, get-tuple-element.591, add.4315) +} // windowed_dot_general_body_ag + +windowed_dot_general_cond_ag { + param = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0) + get-tuple-element = u32[] get-tuple-element(param), index=4 + constant = u32[] constant(4) + ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT +} + +ENTRY main.12_spmd { + param.4 = bf16[16,2048,512]{2,1,0} parameter(0) + param.5 = bf16[4096,6288]{1,0} parameter(1) + constant.22 = bf16[] constant(0) + broadcast = bf16[16,2048,6288]{2,1,0} broadcast(constant.22), dimensions={} + constant.24 = u32[] constant(0) + tuple.2 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24) + while = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag + get-tuple-element.13 = bf16[16,2048,6288]{2,1,0} get-tuple-element(while), index=2 + all-gather = bf16[16,2048,4096]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}, use_global_device_ids=true + param.6 = bf16[16,2048,6288]{2,1,0} parameter(2) + ROOT dot.7 = bf16[4096,6288]{1,0} dot(all-gather, param.6), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(kHloString)); + + GpuWindowedEinsumHandler gpu_handler; + bool changed; + TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); + EXPECT_TRUE(changed); + + HloInstruction* ag_loop = + FindInstructionByName(module->entry_computation(), "while"); + HloInstruction* inst = + FindInstructionByName(module->entry_computation(), "dot.7"); + // dot.7 should now consume output of the windowed einsum while loop. + EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement); + EXPECT_EQ(inst->operand(0)->tuple_index(), 5); + EXPECT_EQ(inst->operand(0)->operand(0), ag_loop); +} } // namespace } // namespace xla::gpu
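
A standalone sanity check, not part of the patch: the new select(compare) rewrite in algebraic_simplifier.cc relies on the boolean identities below for PRED operands, which the pred[8] tests above exercise through HLO. This sketch just verifies the truth tables, with ints 0/1 standing in for PRED values.

#include <cassert>

// For booleans a and b:
//   select(compare(a, b, GT/GE), a, b) == a | b
//   select(compare(a, b, LT/LE), a, b) == a & b
//   select(compare(a, b, EQ), a, b) == b
//   select(compare(a, b, NE), a, b) == a
int main() {
  for (int a = 0; a <= 1; ++a) {
    for (int b = 0; b <= 1; ++b) {
      assert((a > b ? a : b) == (a | b));   // GT
      assert((a >= b ? a : b) == (a | b));  // GE
      assert((a < b ? a : b) == (a & b));   // LT
      assert((a <= b ? a : b) == (a & b));  // LE
      assert((a == b ? a : b) == b);        // EQ
      assert((a != b ? a : b) == a);        // NE
    }
  }
  return 0;
}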