diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 78ae512eb54202..e57c2b30808d82 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -10,6 +10,10 @@ package( licenses = ["notice"], ) +exports_files(glob([ + "testdata/*.bin", +])) + package_group( name = "friends", packages = [ @@ -123,39 +127,39 @@ tf_cc_test( name = "quantize_model_test", srcs = ["quantize_model_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin", - "//tensorflow/lite/tools/optimize:testdata/argmax.bin", - "//tensorflow/lite/tools/optimize:testdata/broadcast_to.bin", - "//tensorflow/lite/tools/optimize:testdata/concat.bin", - "//tensorflow/lite/tools/optimize:testdata/fc.bin", - "//tensorflow/lite/tools/optimize:testdata/fc_qat.bin", - "//tensorflow/lite/tools/optimize:testdata/gather_nd.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized2.bin", - "//tensorflow/lite/tools/optimize:testdata/maximum.bin", - "//tensorflow/lite/tools/optimize:testdata/minimum.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed16x8.bin", - "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin", - "//tensorflow/lite/tools/optimize:testdata/pack.bin", - "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_no_bias.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", - "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/split.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/transpose.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/unpack.bin", - "//tensorflow/lite/tools/optimize:testdata/where.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/add_with_const_input.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/argmax.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/broadcast_to.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/concat.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc_qat.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/gather_nd.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated2.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized2.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/maximum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/minimum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed16x8.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/multi_input_add_reshape.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/pack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_no_bias.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_softmax_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/split.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/transpose.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unpack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/where.bin", ], tags = [ "tflite_not_portable_android", @@ -163,12 +167,12 @@ tf_cc_test( ], deps = [ ":quantize_model", + ":test_util", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", - "//tensorflow/lite/tools/optimize:test_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", @@ -181,13 +185,13 @@ tf_cc_test( name = "quantize_weights_test", srcs = ["quantize_weights_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/custom_op.bin", - "//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/custom_op.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/quantized_with_gather.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/weight_shared_between_convs.bin", ], tags = [ # TODO(b/327796566): re-enable after the bug is fixed @@ -200,15 +204,28 @@ tf_cc_test( ], deps = [ ":quantize_weights", + ":test_util", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", - "//tensorflow/lite/tools/optimize:test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@flatbuffers", "@local_tsl//tsl/platform:logging", ], ) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + "//tensorflow/lite:framework", + "//tensorflow/lite/core/api", + "@com_google_googletest//:gtest", + "@flatbuffers", + ], +) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc index e7d5e00b703392..1e7cdcdea07d33 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/status/status.h" #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" #include "tensorflow/core/platform/init_main.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/tools/optimize/test_util.h" #include "tsl/lib/core/status_test_util.h" // Note: branched from tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -192,7 +192,8 @@ void VerifyQuantizationScale( class QuantizeModelTest : public testing::Test { protected: QuantizeModelTest() { - input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); + input_model_ = + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -277,7 +278,8 @@ class QuantizeConvModelTest : public QuantizeModelTest, protected: QuantizeConvModelTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); + input_model_ = + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); // Flatbuffer is missing calibration data -- add dummy params. @@ -347,7 +349,7 @@ TEST_P(QuantizeConvModelTest, GraphIsFullyQuantized) { class QuantizeConvNoBiasModelTest : public QuantizeModelTest { protected: QuantizeConvNoBiasModelTest() { - input_model_ = ReadModel(internal::kConvModelWithNoBias); + input_model_ = ReadModel(::mlir::lite::internal::kConvModelWithNoBias); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -367,7 +369,7 @@ class QuantizeConvNoBiasModelTest : public QuantizeModelTest { class QuantizeSplitModelTest : public QuantizeModelTest { protected: QuantizeSplitModelTest() { - input_model_ = ReadModel(internal::kModelSplit); + input_model_ = ReadModel(::mlir::lite::internal::kModelSplit); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -452,7 +454,8 @@ class QuantizeConvModel2Test : public QuantizeModelTest, protected: QuantizeConvModel2Test() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kConvModelWith0Plus10Weights); + input_model_ = + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); auto& subgraph = model_.subgraphs[0]; @@ -690,7 +693,8 @@ TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) { class QuantizeSoftmaxTest : public QuantizeModelTest { protected: QuantizeSoftmaxTest() { - input_model_ = ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5); + input_model_ = + ReadModel(::mlir::lite::internal::kSingleSoftmaxModelMinMinus5MaxPlus5); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -753,7 +757,8 @@ TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { class QuantizeAvgPoolTest : public QuantizeModelTest { protected: QuantizeAvgPoolTest() { - input_model_ = ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5); + input_model_ = + ReadModel(::mlir::lite::internal::kSingleAvgPoolModelMinMinus5MaxPlus5); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -813,7 +818,7 @@ TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest { protected: QuantizeMultiInputAddWithReshapeTest() { - input_model_ = ReadModel(internal::kMultiInputAddWithReshape); + input_model_ = ReadModel(::mlir::lite::internal::kMultiInputAddWithReshape); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -933,7 +938,7 @@ class QuantizeConstInputTest : public QuantizeModelTest, protected: QuantizeConstInputTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kConstInputAddModel); + input_model_ = ReadModel(::mlir::lite::internal::kConstInputAddModel); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -980,7 +985,7 @@ TEST_P(QuantizeConstInputTest, VerifyConstOpInput) { class QuantizeArgMaxTest : public QuantizeModelTest { protected: QuantizeArgMaxTest() { - input_model_ = ReadModel(internal::kModelWithArgMaxOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithArgMaxOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1025,7 +1030,7 @@ TEST_F(QuantizeArgMaxTest, VerifyArgMax) { class QuantizeLSTMTest : public QuantizeModelTest { protected: QuantizeLSTMTest() { - input_model_ = ReadModel(internal::kLstmCalibrated); + input_model_ = ReadModel(::mlir::lite::internal::kLstmCalibrated); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1037,7 +1042,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { /*allow_float=*/true, TensorType_INT8, output_buffer_)); // Read expected model. - auto expected_fb_model = ReadModel(internal::kLstmQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1048,7 +1053,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { class QuantizeLSTM2Test : public QuantizeModelTest { protected: QuantizeLSTM2Test() { - input_model_ = ReadModel(internal::kLstmCalibrated2); + input_model_ = ReadModel(::mlir::lite::internal::kLstmCalibrated2); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1061,7 +1066,7 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { /*allow_float=*/false, TensorType_INT8, output_buffer_)); // Read expected model. - auto expected_fb_model = ReadModel(internal::kLstmQuantized2); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized2); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1072,7 +1077,8 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest { protected: QuantizeUnidirectionalSequenceLSTMTest() { - input_model_ = ReadModel(internal::kUnidirectionalSequenceLstmCalibrated); + input_model_ = ReadModel( + ::mlir::lite::internal::kUnidirectionalSequenceLstmCalibrated); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1086,7 +1092,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, // Read expected model. auto expected_fb_model = - ReadModel(internal::kUnidirectionalSequenceLstmQuantized); + ReadModel(::mlir::lite::internal::kUnidirectionalSequenceLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1097,7 +1103,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, class QuantizeSVDFTest : public QuantizeModelTest { protected: QuantizeSVDFTest() { - input_model_ = ReadModel(internal::kSvdfCalibrated); + input_model_ = ReadModel(::mlir::lite::internal::kSvdfCalibrated); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1110,7 +1116,7 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) { /*allow_float=*/false, TensorType_INT8, output_buffer_)); // Read expected model. - auto expected_fb_model = ReadModel(internal::kSvdfQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kSvdfQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1123,7 +1129,7 @@ class QuantizeFCTest : public QuantizeModelTest, protected: QuantizeFCTest() { disable_per_channel_quantization_for_dense_ = GetParam(); - input_model_ = ReadModel(internal::kModelWithFCOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithFCOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1371,7 +1377,7 @@ class QuantizeCustomOpTest public ::testing::WithParamInterface { protected: QuantizeCustomOpTest() { - input_model_ = ReadModel(internal::kModelMixed); + input_model_ = ReadModel(::mlir::lite::internal::kModelMixed); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1409,7 +1415,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest, class QuantizePackTest : public QuantizeModelTest { protected: QuantizePackTest() { - input_model_ = ReadModel(internal::kModelPack); + input_model_ = ReadModel(::mlir::lite::internal::kModelPack); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1526,14 +1532,15 @@ TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { Eq(input2->quantization->zero_point)); } -INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest, - testing::ValuesIn({internal::kModelWithMinimumOp, - internal::kModelWithMaximumOp})); +INSTANTIATE_TEST_SUITE_P( + MinimumMaximumTestInst, QuantizeMinimumMaximumTest, + testing::ValuesIn({::mlir::lite::internal::kModelWithMinimumOp, + ::mlir::lite::internal::kModelWithMaximumOp})); class QuantizeUnpackTest : public QuantizeModelTest { protected: QuantizeUnpackTest() { - input_model_ = ReadModel(internal::kModelWithUnpack); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithUnpack); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1583,7 +1590,7 @@ class QuantizeBroadcastToModelTest protected: QuantizeBroadcastToModelTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kModelWithBroadcastToOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithBroadcastToOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1646,7 +1653,7 @@ class QuantizeGatherNDModelTest protected: QuantizeGatherNDModelTest() { tensor_type_ = GetParam(); - input_model_ = ReadModel(internal::kModelWithGatherNDOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithGatherNDOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } @@ -1706,7 +1713,7 @@ TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) { class QuantizeWhereModelTest : public QuantizeModelTest { protected: QuantizeWhereModelTest() { - input_model_ = ReadModel(internal::kModelWithWhereOp); + input_model_ = ReadModel(::mlir::lite::internal::kModelWithWhereOp); readonly_model_ = input_model_->GetModel(); model_ = UnPackFlatBufferModel(*readonly_model_); } diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc index 2e80bcae7486b4..7a42e74c2619af 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h" #include +#include #include #include #include @@ -26,6 +27,7 @@ limitations under the License. #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" #include "tensorflow/core/platform/init_main.h" @@ -33,7 +35,6 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/tools/optimize/test_util.h" #include "tsl/platform/logging.h" // Note: branched from tensorflow/lite/tools/optimize/quantize_weights_test.cc @@ -59,25 +60,25 @@ std::unique_ptr CreateMutableModelFromFile(const Model* input_model) { std::unique_ptr ReadTestModel() { auto model_path = tensorflow::io::JoinPath( - *g_test_model_dir, internal::kConvModelWith0Plus10Weights); + *g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadSharedWeightsTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kModelWithSharedWeights); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadGatherTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kQuantizedWithGather); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadCustomOpTestModel() { - auto model_path = - tensorflow::io::JoinPath(*g_test_model_dir, internal::kModelWithCustomOp); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp); return FlatBufferModel::BuildFromFile(model_path.c_str()); } diff --git a/tensorflow/lite/tools/optimize/test_util.cc b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc similarity index 95% rename from tensorflow/lite/tools/optimize/test_util.cc rename to tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc index 5ca45326d1dcad..e096868eec8807 100644 --- a/tensorflow/lite/tools/optimize/test_util.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/tools/optimize/test_util.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include -namespace tflite { -namespace optimize { +namespace mlir { +namespace lite { namespace internal { const char* kConvModelWithMinus128Plus127Weights = "single_conv_weights_min_minus_127_max_plus_127.bin"; @@ -89,5 +89,5 @@ int FailOnErrorReporter::Report(const char* format, va_list args) { return 0; } } // namespace internal -} // namespace optimize -} // namespace tflite +} // namespace lite +} // namespace mlir diff --git a/tensorflow/lite/tools/optimize/test_util.h b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h similarity index 93% rename from tensorflow/lite/tools/optimize/test_util.h rename to tensorflow/compiler/mlir/lite/quantization/lite/test_util.h index 11e7ef230910f2..b4e317c131888e 100644 --- a/tensorflow/lite/tools/optimize/test_util.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h @@ -12,13 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_ -#define TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ #include "tensorflow/lite/core/api/error_reporter.h" -namespace tflite { -namespace optimize { +namespace mlir { +namespace lite { namespace internal { // Test model with a single convolution. // Floating point weights of the model are all integers and lie in @@ -132,12 +132,12 @@ extern const char* kQatModelWithFc; extern const char* kModelWithResourceVarsCalibrated; // An error reporter that fails on testing. -class FailOnErrorReporter : public ErrorReporter { +class FailOnErrorReporter : public tflite::ErrorReporter { public: int Report(const char* format, va_list args) override; }; } // namespace internal -} // namespace optimize -} // namespace tflite +} // namespace lite +} // namespace mlir -#endif // TENSORFLOW_LITE_TOOLS_OPTIMIZE_TEST_UTIL_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ diff --git a/tensorflow/lite/tools/optimize/testdata/README.md b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/README.md similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/README.md rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/README.md diff --git a/tensorflow/lite/tools/optimize/testdata/add_with_const_input.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/add_with_const_input.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/add_with_const_input.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/add_with_const_input.bin diff --git a/tensorflow/lite/tools/optimize/testdata/argmax.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/argmax.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/argmax.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/argmax.bin diff --git a/tensorflow/lite/tools/optimize/testdata/broadcast_to.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/broadcast_to.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/broadcast_to.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/broadcast_to.bin diff --git a/tensorflow/lite/tools/optimize/testdata/concat.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/concat.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/concat.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/concat.bin diff --git a/tensorflow/lite/tools/optimize/testdata/custom_op.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/custom_op.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/custom_op.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/custom_op.bin diff --git a/tensorflow/lite/tools/optimize/testdata/fc.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/fc.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc.bin diff --git a/tensorflow/lite/tools/optimize/testdata/fc_qat.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc_qat.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/fc_qat.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/fc_qat.bin diff --git a/tensorflow/lite/tools/optimize/testdata/gather_nd.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/gather_nd.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/gather_nd.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/gather_nd.bin diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated.bin diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_calibrated2.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated2.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_calibrated2.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_calibrated2.bin diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_quantized.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized.bin diff --git a/tensorflow/lite/tools/optimize/testdata/lstm_quantized2.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized2.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/lstm_quantized2.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/lstm_quantized2.bin diff --git a/tensorflow/lite/tools/optimize/testdata/maximum.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/maximum.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/maximum.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/maximum.bin diff --git a/tensorflow/lite/tools/optimize/testdata/minimum.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/minimum.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/minimum.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/minimum.bin diff --git a/tensorflow/lite/tools/optimize/testdata/mixed.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/mixed.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed.bin diff --git a/tensorflow/lite/tools/optimize/testdata/mixed16x8.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed16x8.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/mixed16x8.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/mixed16x8.bin diff --git a/tensorflow/lite/tools/optimize/testdata/multi_input_add_reshape.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/multi_input_add_reshape.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/multi_input_add_reshape.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/multi_input_add_reshape.bin diff --git a/tensorflow/lite/tools/optimize/testdata/pack.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/pack.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/pack.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/pack.bin diff --git a/tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/quantized_with_gather.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/quantized_with_gather.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/quantized_with_gather.bin diff --git a/tensorflow/lite/tools/optimize/testdata/resource_vars_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/resource_vars_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/resource_vars_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/resource_vars_calibrated.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_avg_pool_min_minus_5_max_plus_5.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_avg_pool_min_minus_5_max_plus_5.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_avg_pool_min_minus_5_max_plus_5.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_avg_pool_min_minus_5_max_plus_5.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_no_bias.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_no_bias.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_conv_no_bias.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_no_bias.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_0_max_plus_10.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_0_max_plus_10.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_0_max_plus_10.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_0_max_plus_10.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_minus_127_max_plus_127.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_minus_127_max_plus_127.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_conv_weights_min_minus_127_max_plus_127.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_conv_weights_min_minus_127_max_plus_127.bin diff --git a/tensorflow/lite/tools/optimize/testdata/single_softmax_min_minus_5_max_plus_5.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_softmax_min_minus_5_max_plus_5.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/single_softmax_min_minus_5_max_plus_5.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/single_softmax_min_minus_5_max_plus_5.bin diff --git a/tensorflow/lite/tools/optimize/testdata/split.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/split.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/split.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/split.bin diff --git a/tensorflow/lite/tools/optimize/testdata/svdf_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/svdf_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_calibrated.bin diff --git a/tensorflow/lite/tools/optimize/testdata/svdf_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_quantized.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/svdf_quantized.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/svdf_quantized.bin diff --git a/tensorflow/lite/tools/optimize/testdata/transpose.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/transpose.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/transpose.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/transpose.bin diff --git a/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_calibrated.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_calibrated.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_calibrated.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_calibrated.bin diff --git a/tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_quantized.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_quantized.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/unidirectional_sequence_lstm_quantized.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unidirectional_sequence_lstm_quantized.bin diff --git a/tensorflow/lite/tools/optimize/testdata/unpack.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/unpack.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/unpack.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/unpack.bin diff --git a/tensorflow/lite/tools/optimize/testdata/weight_shared_between_convs.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/weight_shared_between_convs.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/weight_shared_between_convs.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/weight_shared_between_convs.bin diff --git a/tensorflow/lite/tools/optimize/testdata/where.bin b/tensorflow/compiler/mlir/lite/quantization/lite/testdata/where.bin similarity index 100% rename from tensorflow/lite/tools/optimize/testdata/where.bin rename to tensorflow/compiler/mlir/lite/quantization/lite/testdata/where.bin diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index 711f97bdfddd16..a05a5cbdb10710 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -14,10 +14,6 @@ package( licenses = ["notice"], ) -exports_files(glob([ - "testdata/*.bin", -])) - cc_library( name = "reduced_precision_support", srcs = [], @@ -39,7 +35,6 @@ tf_cc_test( ], deps = [ ":reduced_precision_support", - ":test_util", "//tensorflow/core/platform:platform_port", "//tensorflow/lite:framework", "//tensorflow/lite/core:framework", @@ -223,7 +218,6 @@ tf_cc_test( ], deps = [ ":model_utils", - ":test_util", "//tensorflow/lite:framework", "//tensorflow/lite/core:framework", "//tensorflow/lite/schema:schema_fbs", @@ -250,10 +244,10 @@ tf_cc_test( name = "quantization_utils_test", srcs = ["quantization_utils_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - ":testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", ], tags = [ "tflite_not_portable_android", @@ -261,7 +255,7 @@ tf_cc_test( ], deps = [ ":quantization_utils", - ":test_util", + "//tensorflow/compiler/mlir/lite/quantization/lite:test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", @@ -316,13 +310,13 @@ tf_cc_test( name = "quantize_weights_test", srcs = ["quantize_weights_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/custom_op.bin", - "//tensorflow/lite/tools/optimize:testdata/quantized_with_gather.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/weight_shared_between_convs.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/custom_op.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/quantized_with_gather.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/weight_shared_between_convs.bin", ], tags = [ "tflite_not_portable_android", @@ -330,7 +324,7 @@ tf_cc_test( ], deps = [ ":quantize_weights", - ":test_util", + "//tensorflow/compiler/mlir/lite/quantization/lite:test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", @@ -342,19 +336,6 @@ tf_cc_test( ], ) -cc_library( - name = "test_util", - testonly = 1, - srcs = ["test_util.cc"], - hdrs = ["test_util.h"], - deps = [ - "//tensorflow/lite:framework", - "//tensorflow/lite/core/api", - "@com_google_googletest//:gtest", - "@flatbuffers", - ], -) - cc_library( name = "quantize_model", srcs = ["quantize_model.cc"], @@ -379,40 +360,40 @@ tf_cc_test( name = "quantize_model_test", srcs = ["quantize_model_test.cc"], args = [ - "--test_model_file=$(location //tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin)", + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", ], data = [ - "//tensorflow/lite/tools/optimize:testdata/add_with_const_input.bin", - "//tensorflow/lite/tools/optimize:testdata/argmax.bin", - "//tensorflow/lite/tools/optimize:testdata/broadcast_to.bin", - "//tensorflow/lite/tools/optimize:testdata/concat.bin", - "//tensorflow/lite/tools/optimize:testdata/fc.bin", - "//tensorflow/lite/tools/optimize:testdata/fc_qat.bin", - "//tensorflow/lite/tools/optimize:testdata/gather_nd.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_calibrated2.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/lstm_quantized2.bin", - "//tensorflow/lite/tools/optimize:testdata/maximum.bin", - "//tensorflow/lite/tools/optimize:testdata/minimum.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed.bin", - "//tensorflow/lite/tools/optimize:testdata/mixed16x8.bin", - "//tensorflow/lite/tools/optimize:testdata/multi_input_add_reshape.bin", - "//tensorflow/lite/tools/optimize:testdata/pack.bin", - "//tensorflow/lite/tools/optimize:testdata/resource_vars_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_no_bias.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_0_max_plus_10.bin", - "//tensorflow/lite/tools/optimize:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", - "//tensorflow/lite/tools/optimize:testdata/single_softmax_min_minus_5_max_plus_5.bin", - "//tensorflow/lite/tools/optimize:testdata/split.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/svdf_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/transpose.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_calibrated.bin", - "//tensorflow/lite/tools/optimize:testdata/unidirectional_sequence_lstm_quantized.bin", - "//tensorflow/lite/tools/optimize:testdata/unpack.bin", - "//tensorflow/lite/tools/optimize:testdata/where.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/add_with_const_input.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/argmax.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/broadcast_to.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/concat.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/fc_qat.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/gather_nd.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_calibrated2.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/lstm_quantized2.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/maximum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/minimum.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/mixed16x8.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/multi_input_add_reshape.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/pack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/resource_vars_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_avg_pool_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_no_bias.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_minus_127_max_plus_127.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_softmax_min_minus_5_max_plus_5.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/split.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/svdf_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/transpose.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_calibrated.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unidirectional_sequence_lstm_quantized.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/unpack.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/where.bin", ], tags = [ "tflite_not_portable_android", @@ -420,7 +401,7 @@ tf_cc_test( ], deps = [ ":quantize_model", - ":test_util", + "//tensorflow/compiler/mlir/lite/quantization/lite:test_util", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/lite:framework", diff --git a/tensorflow/lite/tools/optimize/model_utils_test.cc b/tensorflow/lite/tools/optimize/model_utils_test.cc index 65e3afe35e2da2..f702e1fa0a0ddd 100644 --- a/tensorflow/lite/tools/optimize/model_utils_test.cc +++ b/tensorflow/lite/tools/optimize/model_utils_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include "tensorflow/lite/core/model.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace tflite { namespace optimize { diff --git a/tensorflow/lite/tools/optimize/quantization_utils_test.cc b/tensorflow/lite/tools/optimize/quantization_utils_test.cc index a09acef6f4aa3c..a0ab9c43eacb75 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils_test.cc +++ b/tensorflow/lite/tools/optimize/quantization_utils_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include "absl/memory/memory.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace { tensorflow::string* g_test_model_dir = nullptr; @@ -46,7 +46,7 @@ std::unique_ptr ReadModel(const char* model) { } std::unique_ptr ReadConvModel() { - return ReadModel(internal::kConvModelWith0Plus10Weights); + return ReadModel(mlir::lite::internal::kConvModelWith0Plus10Weights); } using ::testing::ElementsAreArray; diff --git a/tensorflow/lite/tools/optimize/quantize_model_test.cc b/tensorflow/lite/tools/optimize/quantize_model_test.cc index 681507c8e0d31d..a7e9115f8bdaaa 100644 --- a/tensorflow/lite/tools/optimize/quantize_model_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" -#include "tensorflow/lite/tools/optimize/test_util.h" // Note: More rigorous model tests can be found in subgraph_quantizer_test.cc @@ -78,7 +78,8 @@ TensorType GetBiasTensorType(TensorType& activation_type) { class QuantizeModelTest : public testing::Test { protected: QuantizeModelTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)) {} + : QuantizeModelTest( + ReadModel(mlir::lite::internal::kConvModelWith0Plus10Weights)) {} explicit QuantizeModelTest(std::unique_ptr input_model) { input_model_ = std::move(input_model); @@ -91,7 +92,7 @@ class QuantizeModelTest : public testing::Test { const Model* readonly_model_; tflite::ModelT model_; flatbuffers::FlatBufferBuilder builder_; - internal::FailOnErrorReporter error_reporter_; + ::mlir::lite::internal::FailOnErrorReporter error_reporter_; }; void ExpectSameModels(const ModelT& model, const ModelT& expected_model) { @@ -136,7 +137,8 @@ class QuantizeConvModelTest : public QuantizeModelTest, public testing::WithParamInterface { protected: QuantizeConvModelTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} TensorType tensor_type_; @@ -405,7 +407,8 @@ TEST_P(QuantizeConvModelTest, Uint8InputAndOutput) { class QuantizeConvNoBiasModelTest : public QuantizeModelTest { protected: QuantizeConvNoBiasModelTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWithNoBias)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWithNoBias)) {} }; TEST_F(QuantizeConvNoBiasModelTest, QuantizationSucceeds) { @@ -422,7 +425,8 @@ class QuantizeConcatModelTest : public QuantizeModelTest, public testing::WithParamInterface { protected: QuantizeConcatModelTest() - : QuantizeModelTest(ReadModel(internal::kFloatConcatMax5Max10Max10)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kFloatConcatMax5Max10Max10)) {} void SetUp() override { tensor_type_ = GetParam(); @@ -536,7 +540,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeConcatModelInst, QuantizeConcatModelTest, class QuantizeSplitModelTest : public QuantizeModelTest { protected: QuantizeSplitModelTest() - : QuantizeModelTest(ReadModel(internal::kModelSplit)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelSplit)) {} }; // There are two outputs for split with different scales, the resulting model @@ -601,8 +605,8 @@ TEST_F(QuantizeSplitModelTest, QuantizeSplit) { class QuantizeConvModel1Test : public QuantizeModelTest { protected: QuantizeConvModel1Test() - : QuantizeModelTest( - ReadModel(internal::kConvModelWithMinus128Plus127Weights)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kConvModelWithMinus128Plus127Weights)) {} }; TEST_F(QuantizeConvModel1Test, VerifyConvQuantizationWithUnitScale) { @@ -703,7 +707,8 @@ class QuantizeConvModel2Test : public QuantizeModelTest, public testing::WithParamInterface { protected: QuantizeConvModel2Test() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -925,8 +930,8 @@ TEST_P(QuantizeConvModel2Test, VerifyConvDisablePerChannelQuantization) { class QuantizeSoftmaxTest : public QuantizeModelTest { protected: QuantizeSoftmaxTest() - : QuantizeModelTest( - ReadModel(internal::kSingleSoftmaxModelMinMinus5MaxPlus5)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kSingleSoftmaxModelMinMinus5MaxPlus5)) {} }; TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { @@ -985,8 +990,8 @@ TEST_F(QuantizeSoftmaxTest, VerifySoftmaxQuantization) { class QuantizeAvgPoolTest : public QuantizeModelTest { protected: QuantizeAvgPoolTest() - : QuantizeModelTest( - ReadModel(internal::kSingleAvgPoolModelMinMinus5MaxPlus5)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kSingleAvgPoolModelMinMinus5MaxPlus5)) {} }; TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { @@ -1045,7 +1050,8 @@ TEST_F(QuantizeAvgPoolTest, VerifyAvgPoolQuantization) { class QuantizeMultiInputAddWithReshapeTest : public QuantizeModelTest { protected: QuantizeMultiInputAddWithReshapeTest() - : QuantizeModelTest(ReadModel(internal::kMultiInputAddWithReshape)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kMultiInputAddWithReshape)) {} }; TEST_F(QuantizeMultiInputAddWithReshapeTest, VerifyReshapeQuantization) { @@ -1155,7 +1161,8 @@ class QuantizeConstInputTest : public QuantizeModelTest, public testing::WithParamInterface { protected: QuantizeConstInputTest() - : QuantizeModelTest(ReadModel(internal::kConstInputAddModel)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConstInputAddModel)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1213,7 +1220,8 @@ TEST_P(QuantizeConstInputTest, VerifyConstOpInput) { class QuantizeArgMaxTest : public QuantizeModelTest { protected: QuantizeArgMaxTest() - : QuantizeModelTest(ReadModel(internal::kModelWithArgMaxOp)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithArgMaxOp)) {} }; TEST_F(QuantizeArgMaxTest, VerifyArgMax) { @@ -1254,7 +1262,7 @@ TEST_F(QuantizeArgMaxTest, VerifyArgMax) { class QuantizeLSTMTest : public QuantizeModelTest { protected: QuantizeLSTMTest() - : QuantizeModelTest(ReadModel(internal::kLstmCalibrated)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kLstmCalibrated)) {} }; TEST_F(QuantizeLSTMTest, VerifyLSTM) { @@ -1265,7 +1273,7 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { ASSERT_EQ(kTfLiteOk, status); // Read expected model. - auto expected_fb_model = ReadModel(internal::kLstmQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1276,7 +1284,8 @@ TEST_F(QuantizeLSTMTest, VerifyLSTM) { class QuantizeLSTM2Test : public QuantizeModelTest { protected: QuantizeLSTM2Test() - : QuantizeModelTest(ReadModel(internal::kLstmCalibrated2)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kLstmCalibrated2)) { + } }; TEST_F(QuantizeLSTM2Test, VerifyLSTM) { @@ -1287,7 +1296,7 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { ASSERT_EQ(kTfLiteOk, status); // Read expected model. - auto expected_fb_model = ReadModel(internal::kLstmQuantized2); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kLstmQuantized2); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1298,8 +1307,8 @@ TEST_F(QuantizeLSTM2Test, VerifyLSTM) { class QuantizeUnidirectionalSequenceLSTMTest : public QuantizeModelTest { protected: QuantizeUnidirectionalSequenceLSTMTest() - : QuantizeModelTest( - ReadModel(internal::kUnidirectionalSequenceLstmCalibrated)) {} + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kUnidirectionalSequenceLstmCalibrated)) {} }; TEST_F(QuantizeUnidirectionalSequenceLSTMTest, @@ -1312,7 +1321,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, // Read expected model. auto expected_fb_model = - ReadModel(internal::kUnidirectionalSequenceLstmQuantized); + ReadModel(::mlir::lite::internal::kUnidirectionalSequenceLstmQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1323,7 +1332,7 @@ TEST_F(QuantizeUnidirectionalSequenceLSTMTest, class QuantizeSVDFTest : public QuantizeModelTest { protected: QuantizeSVDFTest() - : QuantizeModelTest(ReadModel(internal::kSvdfCalibrated)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kSvdfCalibrated)) {} }; TEST_F(QuantizeSVDFTest, VerifySVDF) { @@ -1334,7 +1343,7 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) { ASSERT_EQ(kTfLiteOk, status); // Read expected model. - auto expected_fb_model = ReadModel(internal::kSvdfQuantized); + auto expected_fb_model = ReadModel(::mlir::lite::internal::kSvdfQuantized); auto expected_read_only_model = expected_fb_model->GetModel(); ModelT expected_model; expected_read_only_model->UnPackTo(&expected_model); @@ -1379,7 +1388,8 @@ TEST_F(QuantizeSVDFTest, VerifySVDF) { class QuantizeFCTest : public QuantizeModelTest { protected: - QuantizeFCTest() : QuantizeModelTest(ReadModel(internal::kModelWithFCOp)) {} + QuantizeFCTest() + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelWithFCOp)) {} }; TEST_F(QuantizeFCTest, VerifyFC) { @@ -1430,7 +1440,7 @@ class QuantizeCustomOpTest public ::testing::WithParamInterface { protected: QuantizeCustomOpTest() - : QuantizeModelTest(ReadModel(internal::kModelMixed)), + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelMixed)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1471,7 +1481,7 @@ INSTANTIATE_TEST_SUITE_P(QuantizeCustomOpTest, QuantizeCustomOpTest, class QuantizeOp16x8Test : public QuantizeModelTest { protected: QuantizeOp16x8Test() - : QuantizeModelTest(ReadModel(internal::kModelMixed16x8)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelMixed16x8)) {} }; TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) { @@ -1502,7 +1512,8 @@ TEST_F(QuantizeOp16x8Test, VerifyMixedQuantization16x8) { class QuantizePackTest : public QuantizeModelTest { protected: - QuantizePackTest() : QuantizeModelTest(ReadModel(internal::kModelPack)) {} + QuantizePackTest() + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelPack)) {} }; TEST_F(QuantizePackTest, VerifyPack) { @@ -1628,14 +1639,16 @@ TEST_P(QuantizeMinimumMaximumTest, VerifyMinimumMaximum) { EXPECT_EQ(subgraph->tensors[5]->name, "output"); } -INSTANTIATE_TEST_SUITE_P(MinimumMaximumTestInst, QuantizeMinimumMaximumTest, - testing::ValuesIn({internal::kModelWithMinimumOp, - internal::kModelWithMaximumOp})); +INSTANTIATE_TEST_SUITE_P( + MinimumMaximumTestInst, QuantizeMinimumMaximumTest, + testing::ValuesIn({::mlir::lite::internal::kModelWithMinimumOp, + ::mlir::lite::internal::kModelWithMaximumOp})); class QuantizeUnpackTest : public QuantizeModelTest { protected: QuantizeUnpackTest() - : QuantizeModelTest(ReadModel(internal::kModelWithUnpack)) {} + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kModelWithUnpack)) { + } }; TEST_F(QuantizeUnpackTest, VerifyUnpack) { auto status = QuantizeModel(&builder_, &model_, &error_reporter_); @@ -1680,7 +1693,8 @@ TEST_F(QuantizeUnpackTest, VerifyUnpack) { class QuantizeTransposeTest : public QuantizeModelTest { protected: QuantizeTransposeTest() - : QuantizeModelTest(ReadModel(internal::kModelWithTranspose)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithTranspose)) {} }; TEST_F(QuantizeTransposeTest, VerifyTranspose) { @@ -1720,7 +1734,8 @@ TEST_F(QuantizeTransposeTest, VerifyTranspose) { class QuantizeQatTest : public QuantizeModelTest { protected: - QuantizeQatTest() : QuantizeModelTest(ReadModel(internal::kQatModelWithFc)) {} + QuantizeQatTest() + : QuantizeModelTest(ReadModel(::mlir::lite::internal::kQatModelWithFc)) {} }; TEST_F(QuantizeQatTest, VerifySingleQuantize) { @@ -1777,7 +1792,8 @@ class QuantizeBroadcastToModelTest public testing::WithParamInterface { protected: QuantizeBroadcastToModelTest() - : QuantizeModelTest(ReadModel(internal::kModelWithBroadcastToOp)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithBroadcastToOp)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1844,7 +1860,8 @@ class QuantizeGatherNDModelTest public testing::WithParamInterface { protected: QuantizeGatherNDModelTest() - : QuantizeModelTest(ReadModel(internal::kModelWithGatherNDOp)), + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithGatherNDOp)), tensor_type_(GetParam()), bias_type_(GetBiasTensorType(tensor_type_)) {} @@ -1906,7 +1923,8 @@ TEST_P(QuantizeGatherNDModelTest, QuantizeGatherND) { class QuantizeWhereModelTest : public QuantizeModelTest { protected: QuantizeWhereModelTest() - : QuantizeModelTest(ReadModel(internal::kModelWithWhereOp)) {} + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kModelWithWhereOp)) {} }; TEST_F(QuantizeWhereModelTest, QuantizeWhere) { @@ -1976,8 +1994,8 @@ class QuantizeResourcesModelTest public testing::WithParamInterface { protected: QuantizeResourcesModelTest() - : QuantizeModelTest( - ReadModel(internal::kModelWithResourceVarsCalibrated)) { + : QuantizeModelTest(ReadModel( + ::mlir::lite::internal::kModelWithResourceVarsCalibrated)) { TestType obj = GetParam(); tensor_type_ = obj.tensor_type; modify_range_ = obj.modify_range; @@ -2119,7 +2137,8 @@ class QuantizeConcatConstModelTest public testing::WithParamInterface { protected: QuantizeConcatConstModelTest() - : QuantizeModelTest(ReadModel(internal::kFloatConcatMax5Max10Max10)) { + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kFloatConcatMax5Max10Max10)) { // Make one of the values constant. MakeInputConstant(&model_); } @@ -2224,7 +2243,8 @@ class BiasInputTest : public QuantizeModelTest, public testing::WithParamInterface { protected: BiasInputTest() - : QuantizeModelTest(ReadModel(internal::kConvModelWith0Plus10Weights)) { + : QuantizeModelTest( + ReadModel(::mlir::lite::internal::kConvModelWith0Plus10Weights)) { BiasTestType obj = GetParam(); tensor_type_ = obj.tensor_type; bias_type_ = obj.bias_type; diff --git a/tensorflow/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/lite/tools/optimize/quantize_weights_test.cc index 0e9c3efc17acd9..b2279ed34908f6 100644 --- a/tensorflow/lite/tools/optimize/quantize_weights_test.cc +++ b/tensorflow/lite/tools/optimize/quantize_weights_test.cc @@ -22,13 +22,13 @@ limitations under the License. #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "flatbuffers/flexbuffers.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/lite/core/model.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace { tensorflow::string* g_test_model_dir = nullptr; @@ -40,25 +40,25 @@ namespace { std::unique_ptr ReadTestModel() { auto model_path = tensorflow::io::JoinPath( - *g_test_model_dir, internal::kConvModelWith0Plus10Weights); + *g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadSharedWeightsTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kModelWithSharedWeights); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadGatherTestModel() { - auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, - internal::kQuantizedWithGather); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather); return FlatBufferModel::BuildFromFile(model_path.c_str()); } std::unique_ptr ReadCustomOpTestModel() { - auto model_path = - tensorflow::io::JoinPath(*g_test_model_dir, internal::kModelWithCustomOp); + auto model_path = tensorflow::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp); return FlatBufferModel::BuildFromFile(model_path.c_str()); } diff --git a/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc b/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc index 6b5cf538b50c43..19400079b17e96 100644 --- a/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc +++ b/tensorflow/lite/tools/optimize/reduced_precision_support_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_utils.h" #include "tensorflow/lite/testing/util.h" -#include "tensorflow/lite/tools/optimize/test_util.h" namespace tflite { namespace optimize { diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index dd13ffebd9e852..641fedf0c72405 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -8188,6 +8188,33 @@ absl::Status AlgebraicSimplifierVisitor::HandleSelect(HloInstruction* select) { select->mutable_operand(0)->shape(), HloOpcode::kNot, select->mutable_operand(0))); } + // select(compare(a, b, GT/GE), a, b) => or(a, b) + // select(compare(a, b, LT/LE), a, b) => and(a, b) + // select(compare(a, b, EQ), a, b) => b + // select(compare(a, b, NE), a, b) => a + HloInstruction *compare, *lhs, *rhs; + if (Match(select, m::Select(m::Op(&compare), m::Op(&lhs), m::Op(&rhs))) && + Match(compare, m::Compare(m::Op().Is(lhs), m::Op().Is(rhs)))) { + auto cmp_dir = compare->comparison_direction(); + if (cmp_dir == ComparisonDirection::kGt || + cmp_dir == ComparisonDirection::kGe) { + return ReplaceWithNewInstruction( + select, HloInstruction::CreateBinary(select->shape(), + HloOpcode::kOr, lhs, rhs)); + } + if (cmp_dir == ComparisonDirection::kLt || + cmp_dir == ComparisonDirection::kLe) { + return ReplaceWithNewInstruction( + select, HloInstruction::CreateBinary(select->shape(), + HloOpcode::kAnd, lhs, rhs)); + } + if (cmp_dir == ComparisonDirection::kEq) { + return ReplaceInstruction(select, rhs); + } + if (cmp_dir == ComparisonDirection::kNe) { + return ReplaceInstruction(select, lhs); + } + } } // select(pred, xs, dynamic_update_slice(xs, x, i)) diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index 00970a51546b1a..921098aa7565e8 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -736,6 +736,95 @@ TEST_F(AlgebraicSimplifierTest, SelectPredPred2) { GmockMatch(m::Not(m::Parameter(0)))); } +// select(compare(a, b, GT/GE), a, b) => or(a, b), a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectGtCompare) { + for (const auto cmp_dir : {"GT", "GE"}) { + const auto kModuleStr = absl::StrFormat(R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=%s + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )", + cmp_dir); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Or(m::Parameter(0), m::Parameter(1)))); + } +} + +// select(compare(a, b, LT/LE), a, b) => and(a, b), a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectLtCompare) { + for (const auto cmp_dir : {"LT", "LE"}) { + const auto kModuleStr = absl::StrFormat(R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=%s + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )", + cmp_dir); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::And(m::Parameter(0), m::Parameter(1)))); + } +} + +// select(compare(a, b, EQ), a, b) => b, a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectEqCompare) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=EQ + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(1))); +} + +// select(compare(a, b, NE), a, b) => a, a,b ∈ PRED +TEST_F(AlgebraicSimplifierTest, SelectNeCompare) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=NE + ROOT select = pred[8]{0} select(compare, p0, p1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(0))); +} + +// select(compare(a, b, NE), b, a) ≠> a - wrong operands order +TEST_F(AlgebraicSimplifierTest, SelectNeCompare_NegativeTestCase) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[8]{0} parameter(0) + p1 = pred[8]{0} parameter(1) + compare = pred[8]{0} compare(p0, p1), direction=NE + ROOT select = pred[8]{0} select(compare, p1, p0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); +} + // Test that select(pred, xs, dynamic_update_slice(xs, x, i)) is simplified // to dynamic_update_slice(xs, select(pred, dynamic_slice(xs, i), x), i) TEST_F(AlgebraicSimplifierTest, SelectDUSWithShapedPred) { diff --git a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc index bbc32250444b0f..b03df23e41e764 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/select_and_scatter_benchmark_test.cc @@ -76,7 +76,6 @@ BENCHMARK(BM_SelectAndScatterF32) ->Arg(64) ->Arg(128) ->Arg(256) - ->Arg(512) - ->Arg(1024); + ->Arg(512); } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/service/cpu/runtime/BUILD index 8c37031edf7985..f9ad7ef5c51300 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/service/cpu/runtime/BUILD @@ -17,21 +17,36 @@ package_group( cc_library( name = "buffer_allocations", - srcs = ["buffer_allocations.cc"], hdrs = ["buffer_allocations.h"], deps = [ + "//xla:util", "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:statusor", ], ) +xla_cc_test( + name = "buffer_allocations_test", + srcs = ["buffer_allocations_test.cc"], + deps = [ + ":buffer_allocations", + "//xla/service:buffer_assignment", + "//xla/service:maybe_owning_device_memory", + "//xla/stream_executor:device_memory", + "@com_google_absl//absl/status", + "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "task", hdrs = ["task.h"], diff --git a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.cc b/third_party/xla/xla/service/cpu/runtime/buffer_allocations.cc deleted file mode 100644 index e35b931c08e5bc..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/cpu/runtime/buffer_allocations.h" - -#include - -#include "absl/base/optimization.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "xla/service/buffer_assignment.h" -#include "xla/stream_executor/device_memory.h" -#include "tsl/platform/statusor.h" - -namespace xla::cpu { - -absl::StatusOr BufferAllocations::GetDeviceAddress( - BufferAllocation::Index buffer_index) const { - if (ABSL_PREDICT_FALSE(buffer_index < 0 || buffer_index >= buffers_.size())) { - return absl::InvalidArgumentError(absl::StrCat( - "Invalid buffer_index ", buffer_index, - " value. It must be in the range [0, ", buffers_.size(), ")")); - } - - return buffers_[buffer_index].AsDeviceMemoryBase(); -} - -absl::StatusOr BufferAllocations::GetDeviceAddress( - const BufferAllocation::Slice& buffer_slice) const { - // Handle empty slices explicitly and return a null pointer device memory to - // guarantee that we do not accidentally write through the empty slice which - // would hide a real bug in the code. - if (ABSL_PREDICT_FALSE(buffer_slice.size() == 0)) { - return se::DeviceMemoryBase(nullptr, 0); - } - - int64_t index = buffer_slice.index(); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase base, GetDeviceAddress(index)); - - int64_t offset = buffer_slice.offset(); - int64_t extent = offset + buffer_slice.size(); - - if (ABSL_PREDICT_FALSE(offset < 0)) { - return absl::InvalidArgumentError( - absl::StrCat("Buffer slice offset ", offset, " must be non-negative")); - } - - if (ABSL_PREDICT_FALSE(offset >= base.size())) { - return absl::InvalidArgumentError(absl::StrCat( - "Buffer slice offset ", offset, " is out of range for buffer #", index, - " of size ", base.size())); - } - - if (ABSL_PREDICT_FALSE(extent > base.size())) { - return absl::InvalidArgumentError(absl::StrCat( - "Buffer slice extent ", extent, " is out of range for buffer #", index, - " of size ", base.size())); - } - - return base.GetByteSlice(offset, buffer_slice.size()); -} - -} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h b/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h index 76f05390a01b07..7abcff73fb5b66 100644 --- a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h +++ b/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h @@ -16,39 +16,102 @@ limitations under the License. #ifndef XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ #define XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/optimization.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/service/buffer_assignment.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/stream_executor/device_memory.h" +#include "xla/util.h" namespace xla::cpu { // Buffer allocation is a container for device buffers allocated for a // particular XLA execution. Buffers are indexed by the buffer allocation index. -// -// TODO(b/342513610): BufferAllocations should be unified with a same class in -// the XLA:GPU runtime, probably as a part of `buffer_assignment.h`. class BufferAllocations { public: - explicit BufferAllocations(absl::Span buffers) - : buffers_(buffers) {} + explicit inline BufferAllocations( + absl::Span buffers); // Returns the device address of buffer `buffer_index`. `buffer_index` must be // a valid index, i.e., in [0, buffer_count). - absl::StatusOr GetDeviceAddress( - BufferAllocation::Index buffer_index) const; + inline ABSL_ATTRIBUTE_ALWAYS_INLINE absl::StatusOr + GetDeviceAddress(BufferAllocation::Index buffer_index) const; // Same as above, but also adjusts the returned address for the offset and // size contained in the given slice. - absl::StatusOr GetDeviceAddress( - const BufferAllocation::Slice& buffer_slice) const; + inline ABSL_ATTRIBUTE_ALWAYS_INLINE absl::StatusOr + GetDeviceAddress(const BufferAllocation::Slice& buffer_slice) const; private: - // TODO(ezhulenev): Make BufferAllocations an owner of the buffers. - absl::Span buffers_; // not owned + std::vector buffers_; + size_t num_buffers_; }; +BufferAllocations::BufferAllocations( + absl::Span buffers) + : buffers_(buffers.size()), num_buffers_(buffers_.size()) { + for (size_t i = 0; i < buffers.size(); ++i) { + buffers_[i] = buffers[i].AsDeviceMemoryBase(); + } +} + +absl::StatusOr BufferAllocations::GetDeviceAddress( + BufferAllocation::Index index) const { + if (ABSL_PREDICT_FALSE(index < 0 || index >= num_buffers_)) { + return InvalidArgument( + "Invalid buffer index %d. It must be in the range [0, %d)", index, + num_buffers_); + } + + return buffers_[index]; +} + +absl::StatusOr BufferAllocations::GetDeviceAddress( + const BufferAllocation::Slice& buffer_slice) const { + // Handle empty slices explicitly and return a null pointer device memory to + // guarantee that we do not accidentally write through the empty slice which + // would hide a real bug in the code. + if (ABSL_PREDICT_FALSE(buffer_slice.size() == 0)) { + return se::DeviceMemoryBase(nullptr, 0); + } + + int64_t index = buffer_slice.index(); + if (ABSL_PREDICT_FALSE(index < 0 || index >= num_buffers_)) { + return InvalidArgument( + "Invalid buffer index %d. It must be in the range [0, %d)", index, + num_buffers_); + } + const se::DeviceMemoryBase& base = buffers_[index]; + + int64_t offset = buffer_slice.offset(); + int64_t extent = offset + buffer_slice.size(); + + if (ABSL_PREDICT_FALSE(offset < 0)) { + return InvalidArgument("Buffer slice offset %d must be non-negative", + offset); + } + + if (ABSL_PREDICT_FALSE(offset >= base.size())) { + return InvalidArgument( + "Buffer slice offset %d is out of range for buffer #%d of size %d", + offset, index, base.size()); + } + + if (ABSL_PREDICT_FALSE(extent > base.size())) { + return InvalidArgument( + "Buffer slice extent %d is out of range for buffer #%d of size %d", + extent, index, base.size()); + } + + return base.GetByteSlice(offset, buffer_slice.size()); +} + } // namespace xla::cpu #endif // XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc b/third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc new file mode 100644 index 00000000000000..f281924e2542ac --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/runtime/buffer_allocations.h" + +#include +#include + +#include "xla/service/buffer_assignment.h" +#include "xla/service/maybe_owning_device_memory.h" +#include "xla/stream_executor/device_memory.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +TEST(BufferAllocationsTest, GetDeviceAddress) { + std::vector buffers; + std::vector data = {1.0, 2.0, 3.0, 4.0}; + + size_t size_in_bytes = data.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(data.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + + BufferAllocation alloc(0, size_in_bytes, 0); + BufferAllocation::Slice slice(&alloc, /*offset=*/2 * sizeof(float), + /*size=*/sizeof(float)); + + TF_ASSERT_OK_AND_ASSIGN(se::DeviceMemoryBase alloc_mem, + allocations.GetDeviceAddress(0)); + EXPECT_EQ(alloc_mem.opaque(), &data[0]); + + TF_ASSERT_OK_AND_ASSIGN(se::DeviceMemoryBase slice_mem, + allocations.GetDeviceAddress(slice)); + EXPECT_EQ(slice_mem.opaque(), &data[2]); +} + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc index a8d793d1076071..21c57fef35f940 100644 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc @@ -87,40 +87,46 @@ tsl::AsyncValueRef KernelThunk::Execute( kernel_name_, arguments_buffers_.size(), results_buffers_.size(), thread_dim_.ToString()); - absl::InlinedVector kernel_args; - kernel_args.reserve(arguments_buffers_.size() + results_buffers_.size()); + int64_t num_args = arguments_buffers_.size() + results_buffers_.size(); + absl::InlinedVector kernel_args(num_args); + + // We initialize `kernel_args` array using pointer to the first argument, + // because individual elements access adds up measurable overhead, and this + // code is on the critical path. + SE_HOST_KernelArg* kernel_args_ptr = kernel_args.data(); + int64_t kernel_arg_idx = 0; int64_t arg_num = 0; for (BufferAllocation::Slice& buffer : arguments_buffers_) { TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase arg_data, params.buffer_allocations->GetDeviceAddress(buffer)); - kernel_args.push_back( - SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()}); VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", arg_num++, - buffer.ToString(), kernel_args.back().data); + buffer.ToString(), arg_data.opaque()); + kernel_args_ptr[kernel_arg_idx++] = + SE_HOST_KernelArg{arg_data.opaque(), arg_data.size()}; } int64_t res_num = 0; for (BufferAllocation::Slice& buffer : results_buffers_) { TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result_data, params.buffer_allocations->GetDeviceAddress(buffer)); - kernel_args.push_back( - SE_HOST_KernelArg{result_data.opaque(), result_data.size()}); VLOG(3) << absl::StreamFormat(" res #%d: %s (%p)", res_num++, - buffer.ToString(), kernel_args.back().data); + buffer.ToString(), result_data.opaque()); + kernel_args_ptr[kernel_arg_idx++] = + SE_HOST_KernelArg{result_data.opaque(), result_data.size()}; } // Check that all buffers are aligned to the minimum alignment. We codegen // with the assumption that all buffers are aligned, and if they are not, we // will crash with a segmentation fault, or worse, produce incorrect results. if (min_alignment_.has_value()) { - for (int64_t i = 0; i < kernel_args.size(); ++i) { - auto ptr = reinterpret_cast(kernel_args[i].data); + for (int64_t i = 0; i < num_args; ++i) { + auto ptr = reinterpret_cast(kernel_args_ptr[i].data); if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) { return Internal( "Host kernel %s buffer argument #%d (%p) is not aligned to a " "required minimum alignment of %d bytes", - info().op_name, i, kernel_args[i].data, *min_alignment_); + info().op_name, i, kernel_args_ptr[i].data, *min_alignment_); } } } @@ -136,7 +142,7 @@ tsl::AsyncValueRef KernelThunk::Execute( params.host_kernels->Find(kernel_name_)); absl::MutexLock lock(&mutex_); - kernel_.emplace(kernel_args.size(), kernel_fn, nullptr); + kernel_.emplace(num_args, kernel_fn, nullptr); kernel_ptr_.store(kernel = &kernel_.value()); } diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc index ce1dfa4f1c7863..8f5e26124f24a4 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc +++ b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc @@ -378,8 +378,16 @@ absl::StatusOr HandleAgWindowedEinsumLoop(HloComputation* comp, return changed; } +static int64_t GetAgActivationCacheIndex(const HloInstruction* while_loop) { + const HloInstruction* loop_tuple = while_loop->operand(0); + const Shape& tuple_shape = loop_tuple->shape(); + CHECK(tuple_shape.IsTuple()); + return tuple_shape.tuple_shapes_size(); +} + absl::Status ProcessWindowedEinsumLoopForActivationCaching( - GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop) { + GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop, + HloInstruction* ag_with_shared_operand) { HloInstruction* loop = ag_loop.loop; // Transform the while body to cache the allgathered result in the // output buffer to be consumed by the dot @@ -392,15 +400,61 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( } // Get the output operand of the full buffer. HloInstruction* root = while_body->root_instruction(); + // Change loop body to include the new input and output element. + HloInstruction* input_tuple = while_body->parameter_instruction(0); + const Shape& input_shape = input_tuple->shape(); // The full buffer that we will use to cache the accumulated activation - // is the 4th operand in the output tuple. - int64_t full_cache_buffer_index = 3; + // is the last operand in the output tuple. + int64_t full_cache_buffer_index = GetAgActivationCacheIndex(loop); + std::vector new_input_shapes(input_shape.tuple_shapes().begin(), + input_shape.tuple_shapes().end()); + new_input_shapes.push_back(ag_with_shared_operand->shape()); + // Update body input shape + Shape new_input_shape = ShapeUtil::MakeTupleShape(new_input_shapes); + *input_tuple->mutable_shape() = new_input_shape; HloInstruction* full_buffer_output_gte = - root->mutable_operand(full_cache_buffer_index); - HloInstruction* new_full_buffer_output; + while_body->AddInstruction(HloInstruction::CreateGetTupleElement( + ag_with_shared_operand->shape(), input_tuple, + full_cache_buffer_index)); + + // Update condition input shape + HloComputation* cond_comp = loop->while_condition(); + HloInstruction* cond_input_tuple = cond_comp->parameter_instruction(0); + *cond_input_tuple->mutable_shape() = new_input_shape; + + // Update input to the while instruction in parent computation + HloInstruction* original_while_input = loop->mutable_operand(0); + HloComputation* parent_comp = loop->parent(); + std::vector new_operands( + original_while_input->operands().begin(), + original_while_input->operands().end()); + new_operands.push_back( + parent_comp->AddInstruction(HloInstruction::CreateBroadcast( + ag_with_shared_operand->shape(), + parent_comp->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::Zero(new_input_shapes[0].element_type()))), + {}))); + HloInstruction* new_while_input = + parent_comp->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR( + loop->ReplaceOperandWithDifferentShape(0, new_while_input)); + TF_RETURN_IF_ERROR(parent_comp->ReplaceInstructionWithDifferentShape( + original_while_input, new_while_input)); + *loop->mutable_shape() = new_input_shape; + + HloInstruction* new_full_buffer_output = nullptr; // Find the DUS in the loop body and re-use the slice indices // This should just be a constant(0) HloInstruction* dus_boundary_constant; + // The slice we need this time is the output of the first + // collective-permute + HloInstruction* first_cp_output; + for (HloInstruction* gte_user : input_gte->users()) { + if (gte_user->opcode() == HloOpcode::kCollectivePermute) { + first_cp_output = gte_user; + break; + } + } for (HloInstruction* inst : while_body->MakeInstructionPostOrder()) { HloInstruction* slice_indices; // If we have a DUS(PARAM,DS) pattern, we need to update the output @@ -434,24 +488,68 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( dus_boundary_constant->shape(), slice_indices)); VLOG(5) << "Created slice op for second slice: " << slice_indices->ToString(); - // The slice we need this time is the output of the first - // collective-permute - HloInstruction* cp_output; - for (HloInstruction* gte_user : input_gte->users()) { - if (gte_user->opcode() == HloOpcode::kCollectivePermute) { - cp_output = gte_user; - break; - } - } new_full_buffer_output = while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( full_buffer_output_gte->shape(), full_buffer_output_gte, - cp_output, + first_cp_output, {dus_boundary_constant, slice_indices, dus_boundary_constant})); } + + // If we have a Dot(DS(parameter_index1)), then operands are sharded along + // the contracting dim. Slice indices will be the contracting dim's slices. + HloInstruction* slice_index; + HloInstruction* ds_index_constant; + HloInstruction* remainder; + HloInstruction* ds_param; + // There will be 2 dynamic-slices for unrolled loops, match for each one to + // get the slice index which will be used to write the corresponding + // received shard into cached activation buffer. For unrolled loops, we need + // to write to the final buffer twice per iteration, so we need to match for + // the correct slice index based on each DS. + if (Match(inst, m::Dot(m::Op(), m::DynamicSlice(&ds_param))) && + Match(ds_param->operand(0), m::GetTupleElement(m::Parameter(), 1))) { + for (int64_t ds_op_i = 1; ds_op_i < ds_param->operands().size(); + ds_op_i++) { + if (!Match( + ds_param->mutable_operand(ds_op_i), + m::Reshape(&slice_index, m::DynamicSlice(m::Constant(), + m::Op(&remainder)))) && + !Match(ds_param->mutable_operand(ds_op_i), + m::Constant(&ds_index_constant))) { + return absl::OkStatus(); + } + } + // First DS has slice index calculated based on loop iterator + // Remainder(add(gte, partition_id)) + if (Match(remainder, + m::Remainder(m::Add(m::GetTupleElement(), m::Op()), m::Op()))) { + full_buffer_output_gte = + while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + full_buffer_output_gte->shape(), full_buffer_output_gte, + input_gte, + {ds_index_constant, ds_index_constant, slice_index})); + } + // Second DS has slice index calculated based on loop iterator+1 hence + // Remainder(add(add(gte, 1), partition_id)) + if (Match(remainder, + m::Remainder( + m::Add(m::Add(m::GetTupleElement(), m::Op()), m::Op()), + m::Op()))) { + new_full_buffer_output = + while_body->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + full_buffer_output_gte->shape(), full_buffer_output_gte, + first_cp_output, + {ds_index_constant, ds_index_constant, slice_index})); + } + } } - TF_RETURN_IF_ERROR(root->ReplaceOperandWith(full_cache_buffer_index, - new_full_buffer_output)); + std::vector original_operands(root->operands().begin(), + root->operands().end()); + original_operands.push_back(new_full_buffer_output); + HloInstruction* new_output_tuple = while_body->AddInstruction( + HloInstruction::CreateTuple(original_operands)); + TF_RETURN_IF_ERROR( + while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple)); return absl::OkStatus(); } @@ -620,17 +718,20 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { VLOG(5) << "Found all-gather that shares the same operand with a " "windowed einsum loop : " << loop->ToString(); + + if (!ag_loop.consumed) { + TF_RETURN_IF_ERROR(ProcessWindowedEinsumLoopForActivationCaching( + ag_loop, ag_with_shared_operand)); + ag_loop.consumed = true; + } int64_t cache_output_index = dot->operand_index(ag_with_shared_operand); - HloInstruction* new_gte = comp->AddInstruction( - HloInstruction::CreateGetTupleElement(loop, 3)); + HloComputation* comp = dot->parent(); + HloInstruction* new_gte = + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + loop, GetAgActivationCacheIndex(loop) - 1)); TF_RETURN_IF_ERROR( dot->ReplaceOperandWith(cache_output_index, new_gte)); TF_RETURN_IF_ERROR(comp->RemoveInstruction(ag_with_shared_operand)); - if (!ag_loop.consumed) { - TF_RETURN_IF_ERROR( - ProcessWindowedEinsumLoopForActivationCaching(ag_loop)); - ag_loop.consumed = true; - } } } // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc index 23257e1c71a34b..6f23319980e90c 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc @@ -269,23 +269,22 @@ ENTRY main.12_spmd { FindInstructionByName(module->entry_computation(), "dot.7"); // dot.7 should now consume output of the windowed einsum while loop. EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement); - EXPECT_EQ(inst->operand(0)->tuple_index(), 3); + EXPECT_EQ(inst->operand(0)->tuple_index(), 5); EXPECT_EQ(inst->operand(0)->operand(0), ag_loop); // while loop's root should now have a chain of DUS. HloInstruction* ag_while_root = ag_loop->while_body()->root_instruction(); EXPECT_THAT(ag_while_root, GmockMatch(m::Tuple( - m::Op(), m::Op(), m::Op(), + m::Op(), m::Op(), m::Op(), m::Op(), m::Op(), m::DynamicUpdateSlice( m::DynamicUpdateSlice( m::GetTupleElement(m::Parameter()) .WithPredicate([](const HloInstruction* instr) { - return instr->tuple_index() == 3; + return instr->tuple_index() == 5; }), m::Op(), m::Op(), m::Op(), m::Op()), - m::Op(), m::Op(), m::Op(), m::Op()), - m::Op()))); + m::Op(), m::Op(), m::Op(), m::Op())))); } TEST_F(GpuWindowedEinsumHanlderTest, A2aGemmHaveStreamIds) { constexpr absl::string_view kHloString = R"( @@ -838,5 +837,82 @@ ENTRY main.9_spmd { )"); } +TEST_F(GpuWindowedEinsumHanlderTest, + AgLoopsMultipleConsumersAreChainedWithShardedContratingDim) { + constexpr absl::string_view kHloString = R"( +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0})->bf16[4096,6288]{1,0}}, num_partitions=8 + +windowed_dot_general_body_ag { + param.195 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0) + get-tuple-element.588 = bf16[16,2048,512]{2,1,0} get-tuple-element(param.195), index=0 + collective-permute.194 = bf16[16,2048,512]{2,1,0} collective-permute(get-tuple-element.588), channel_id=446, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}} + collective-permute.195 = bf16[16,2048,512]{2,1,0} collective-permute(collective-permute.194), channel_id=447, source_target_pairs={{0,7},{1,0},{2,1},{3,2},{4,3},{5,4},{6,5},{7,6}} + get-tuple-element.589 = bf16[4096,6288]{1,0} get-tuple-element(param.195), index=1 + get-tuple-element.590 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=2 + constant.11432 = s32[8]{0} constant({0, 512, 1024, 1536, 2048, 2560, 3072, 3584}) + get-tuple-element.592 = u32[] get-tuple-element(param.195), index=4 + partition-id.194 = u32[] partition-id() + add.4309 = u32[] add(get-tuple-element.592, partition-id.194) + constant.11431 = u32[] constant(8) + remainder.194 = u32[] remainder(add.4309, constant.11431) + dynamic-slice.388 = s32[1]{0} dynamic-slice(constant.11432, remainder.194), dynamic_slice_sizes={1} + reshape.12959 = s32[] reshape(dynamic-slice.388) + constant.11433 = s32[] constant(0) + dynamic-slice.389 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12959, constant.11433), dynamic_slice_sizes={512,6288} + dot.244 = bf16[16,2048,6288]{2,1,0} dot(get-tuple-element.588, dynamic-slice.389), lhs_contracting_dims={2}, rhs_contracting_dims={0} + add.4310 = bf16[16,2048,6288]{2,1,0} add(get-tuple-element.590, dot.244) + constant.11434 = u32[] constant(1) + add.4312 = u32[] add(get-tuple-element.592, constant.11434) + add.4313 = u32[] add(add.4312, partition-id.194) + remainder.195 = u32[] remainder(add.4313, constant.11431) + dynamic-slice.390 = s32[1]{0} dynamic-slice(constant.11432, remainder.195), dynamic_slice_sizes={1} + reshape.12960 = s32[] reshape(dynamic-slice.390) + dynamic-slice.391 = bf16[512,6288]{1,0} dynamic-slice(get-tuple-element.589, reshape.12960, constant.11433), dynamic_slice_sizes={512,6288} + dot.245 = bf16[16,2048,6288]{2,1,0} dot(collective-permute.194, dynamic-slice.391), lhs_contracting_dims={2}, rhs_contracting_dims={0} + add.4314 = bf16[16,2048,6288]{2,1,0} add(add.4310, dot.245) + get-tuple-element.591 = bf16[16,2048,6288]{2,1,0} get-tuple-element(param.195), index=3 + add.4315 = u32[] add(add.4312, constant.11434) + ROOT tuple.98 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(collective-permute.195, get-tuple-element.589, add.4314, get-tuple-element.591, add.4315) +} // windowed_dot_general_body_ag + +windowed_dot_general_cond_ag { + param = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0) + get-tuple-element = u32[] get-tuple-element(param), index=4 + constant = u32[] constant(4) + ROOT compare = pred[] compare(get-tuple-element, constant), direction=LT +} + +ENTRY main.12_spmd { + param.4 = bf16[16,2048,512]{2,1,0} parameter(0) + param.5 = bf16[4096,6288]{1,0} parameter(1) + constant.22 = bf16[] constant(0) + broadcast = bf16[16,2048,6288]{2,1,0} broadcast(constant.22), dimensions={} + constant.24 = u32[] constant(0) + tuple.2 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24) + while = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag + get-tuple-element.13 = bf16[16,2048,6288]{2,1,0} get-tuple-element(while), index=2 + all-gather = bf16[16,2048,4096]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}, use_global_device_ids=true + param.6 = bf16[16,2048,6288]{2,1,0} parameter(2) + ROOT dot.7 = bf16[4096,6288]{1,0} dot(all-gather, param.6), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + GpuWindowedEinsumHandler gpu_handler; + bool changed; + TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); + EXPECT_TRUE(changed); + + HloInstruction* ag_loop = + FindInstructionByName(module->entry_computation(), "while"); + HloInstruction* inst = + FindInstructionByName(module->entry_computation(), "dot.7"); + // dot.7 should now consume output of the windowed einsum while loop. + EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement); + EXPECT_EQ(inst->operand(0)->tuple_index(), 5); + EXPECT_EQ(inst->operand(0)->operand(0), ag_loop); +} } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/stream_executor/host/host_kernel.cc b/third_party/xla/xla/stream_executor/host/host_kernel.cc index 04586b5272432b..cad37e1bfa4fb0 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel.cc +++ b/third_party/xla/xla/stream_executor/host/host_kernel.cc @@ -67,8 +67,7 @@ class HostKernelExecuteState : public tsl::ReferenceCounted { public: HostKernelExecuteState(HostKernel::TaskRunner task_runner, - HostKernel::KernelFunction* function, - ThreadDim thread_dims, + SE_HOST_Kernel* kernel, ThreadDim thread_dims, absl::Span args); // Notify of a completion of a host kernel task. @@ -112,6 +111,7 @@ HostKernel::HostKernel(std::shared_ptr thread_pool) HostKernel::HostKernel(unsigned arity, SE_HOST_Kernel* kernel, std::shared_ptr thread_pool) : function_(std::make_unique(kernel)), + kernel_(function_->kernel()), arity_(arity), thread_pool_(thread_pool) {} @@ -130,8 +130,6 @@ absl::Status HostKernel::Launch( thread_dims.z, }; - SE_HOST_Kernel* kernel = function_->kernel(); - for (uint64_t z = 0; z < thread_dims.z; ++z) { for (uint64_t y = 0; y < thread_dims.y; ++y) { for (uint64_t x = 0; x < thread_dims.x; ++x) { @@ -140,7 +138,7 @@ absl::Status HostKernel::Launch( SE_HOST_KernelCallFrame call_frame = { &kernel_thread_dims, &kernel_thread, args.size(), args.data()}; - SE_HOST_KernelError* error = (*kernel)(&call_frame); + SE_HOST_KernelError* error = (*kernel_)(&call_frame); if (ABSL_PREDICT_FALSE(error != nullptr)) { return absl::InternalError("Failed to call host kernel"); @@ -174,8 +172,8 @@ tsl::AsyncValueRef HostKernel::Launch( } // Allocate a control structure that will orchestrate kernel execution. - auto state = tsl::MakeRef( - std::move(task_runner), function_.get(), thread_dims, args); + auto state = tsl::MakeRef(std::move(task_runner), + kernel_, thread_dims, args); state->CallAsync(/*start_index=*/0, /*end_index=*/num_tasks); @@ -183,11 +181,11 @@ tsl::AsyncValueRef HostKernel::Launch( } HostKernelExecuteState::HostKernelExecuteState( - HostKernel::TaskRunner task_runner, HostKernel::KernelFunction* function, + HostKernel::TaskRunner task_runner, SE_HOST_Kernel kernel, ThreadDim thread_dims, absl::Span args) : task_runner_(std::move(task_runner)), num_tasks_(thread_dims.x * thread_dims.y * thread_dims.z), - kernel_(function->kernel()), + kernel_(kernel), thread_dims_({thread_dims.x, thread_dims.y, thread_dims.z}), args_(args.begin(), args.end()), abort_(false), diff --git a/third_party/xla/xla/stream_executor/host/host_kernel.h b/third_party/xla/xla/stream_executor/host/host_kernel.h index 9d278b2b79c357..9bc96cb9e7ca2a 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel.h +++ b/third_party/xla/xla/stream_executor/host/host_kernel.h @@ -113,10 +113,12 @@ class HostKernel : public Kernel { std::enable_if_t>* = nullptr> void SetKernelFunction(std::unique_ptr function) { function_ = std::move(function); + kernel_ = function_->kernel(); } private: std::unique_ptr function_; + SE_HOST_Kernel* kernel_; // pointer to the kernel owned by `function_` unsigned arity_; std::shared_ptr thread_pool_;