From 57984a0ad700aeae2a3438d4a2a1d09a3ff4bf53 Mon Sep 17 00:00:00 2001 From: Arturo Schmidt Date: Tue, 26 Mar 2024 18:13:49 -0700 Subject: [PATCH] Rename mlir_import_options to saved_model_import_options. PiperOrigin-RevId: 619370979 --- tensorflow/BUILD | 2 +- tensorflow/compiler/mlir/BUILD | 1 + tensorflow/compiler/mlir/lite/BUILD | 2 +- .../compiler/mlir/lite/tf_to_tfl_flatbuffer.cc | 4 ++-- tensorflow/compiler/mlir/python/BUILD | 1 + tensorflow/compiler/mlir/python/mlir.cc | 5 +++-- .../mlir/quantization/stablehlo/cc/BUILD | 4 ++-- .../stablehlo/cc/saved_model_import.cc | 6 +++--- .../mlir/quantization/tensorflow/python/BUILD | 4 ++-- .../tensorflow/python/quantize_model.cc | 4 ++-- tensorflow/compiler/mlir/tensorflow/BUILD | 2 +- .../compiler/mlir/tensorflow/translate/BUILD | 8 ++++---- .../mlir/tensorflow/translate/import_model.cc | 18 +++++++++--------- .../mlir/tensorflow/translate/import_model.h | 9 +++++---- ..._options.h => saved_model_import_options.h} | 12 ++++-------- .../tensorflow/translate/tf_mlir_translate.cc | 7 ++++--- .../tensorflow/translate/tf_mlir_translate.h | 6 +++--- .../compiler/mlir/tf_mlir_translate_main.cc | 5 +++-- 18 files changed, 51 insertions(+), 49 deletions(-) rename tensorflow/compiler/mlir/tensorflow/translate/{mlir_import_options.h => saved_model_import_options.h} (80%) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 71487e2aec0bee..ba2ddeaafa420d 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -1389,7 +1389,7 @@ tf_cc_shared_library( "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:export_graphdef", - "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/compiler/mlir/tensorflow:saved_model_import_options", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "@local_xla//xla/service:computation_placer", diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index a94e356a2aa48d..031f65d335fd07 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -229,6 +229,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow:translate_registration", "//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op", + "//tensorflow/compiler/mlir/tensorflow/translate:saved_model_import_options", "//tensorflow/core:lib", "//tensorflow/core:tensorflow", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 231fc394598331..72fd1aa8c84738 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -1445,8 +1445,8 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:saved_model_import_options", "//tensorflow/compiler/mlir/tensorflow:tf_dialect_lib", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index dd8b345862e3c8..786e816d4bb269 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -68,8 +68,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/op.h" @@ -551,7 +551,7 @@ absl::StatusOr> ImportSavedModel( if (!module_or.status().ok()) return module_or.status(); return std::move(module_or).value(); } else if (saved_model_version == 1) { - MLIRImportOptions options; + SavedModelImportOptions options; options.upgrade_legacy = specs.upgrade_legacy; options.unconditionally_use_set_output_shapes = true; options.lift_variables = enable_variable_lifting; diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index afc088517dc35f..5627cf3c65b5f6 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -46,6 +46,7 @@ cc_library( "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:tfe_context_internal", "@local_xla//xla/mlir/framework/transforms:passes", + "//tensorflow/compiler/mlir/tensorflow/translate:saved_model_import_options", "@local_xla//xla/mlir_hlo:all_passes", "//tensorflow/compiler/mlir/lite:flatbuffer_import", "//tensorflow/compiler/mlir/lite:tensorflow_lite", diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index de1226c68c39d0..9409f229b2b400 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -55,6 +55,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" @@ -275,7 +276,7 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite( mlir::func::registerAllExtensions(registry); mlir::MLIRContext context(registry); - tensorflow::MLIRImportOptions import_options; + tensorflow::SavedModelImportOptions import_options; import_options.upgrade_legacy = upgrade_legacy; auto module_or = SavedModelSignatureDefsToMlirImportLite( saved_model_path, tag_set, absl::Span(exported_names), @@ -312,7 +313,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir( mlir::DialectRegistry registry; mlir::func::registerAllExtensions(registry); mlir::MLIRContext context(registry); - tensorflow::MLIRImportOptions import_options; + tensorflow::SavedModelImportOptions import_options; import_options.upgrade_legacy = upgrade_legacy; import_options.lift_variables = lift_variables; import_options.include_variables_in_initializers = diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD index f175dfdc9ea1a2..baa289f4d49201 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -217,7 +217,7 @@ cc_library( "//tensorflow/cc/saved_model:reader", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", - "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/compiler/mlir/tensorflow:saved_model_import_options", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/algorithm:container", @@ -417,7 +417,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/quantization/tensorflow/python:unfreeze_constants", - "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/compiler/mlir/tensorflow:saved_model_import_options", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/core:core_cpu_base", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.cc index 295ab06eb1bf70..c6478af3ef1e47 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_import.cc @@ -39,7 +39,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tsl/platform/errors.h" @@ -48,8 +48,8 @@ limitations under the License. namespace mlir::quant::stablehlo { using ::stablehlo::quantization::QuantizationConfig; -using ::tensorflow::MLIRImportOptions; using ::tensorflow::SavedModelBundle; +using ::tensorflow::SavedModelImportOptions; using ::tensorflow::SavedModelSignatureDefsToMlirImport; using ::tensorflow::quantization::PreprocessAndFreezeGraph; @@ -58,7 +58,7 @@ absl::StatusOr SavedModelToMlirModuleOp( const std::unordered_set& tags, const std::vector& signature_keys, MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND) { - MLIRImportOptions import_options; + SavedModelImportOptions import_options; import_options.upgrade_legacy = true; import_options.lift_variables = false; import_options.include_variables_in_initializers = true; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index 6630119fbaaac1..4f8ea6780409ac 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -56,8 +56,8 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op", # Required for CustomAggregator op registration. "//tensorflow/compiler/mlir/quantization/tensorflow/cc:convert_asset_args", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", - "//tensorflow/compiler/mlir/quantization/tensorflow/debugging:dump_tensor_op", # Required for DumpTensor op registration. - "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/compiler/mlir/quantization/tensorflow/debugging:dump_tensor_op", + "//tensorflow/compiler/mlir/tensorflow:saved_model_import_options", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index e055d9dda95bbb..d2f85ca7cd8b33 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -56,7 +56,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -98,7 +98,7 @@ absl::StatusOr> ImportAndPreprocessSavedModel( const bool deserialize_xla_call_module, absl::flat_hash_map &function_aliases) { // Convert the SavedModelBundle to an MLIR module. - MLIRImportOptions import_options; + SavedModelImportOptions import_options; import_options.upgrade_legacy = true; import_options.lift_variables = false; import_options.include_variables_in_initializers = true; diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 7fda5b22e08ca2..eb526397b74802 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1668,7 +1668,7 @@ aliased_targets = [ "mlir_roundtrip_pass", "mlir_roundtrip_pass_registration", "mlir_roundtrip_flags", - "mlir_import_options", + "saved_model_import_options", "translate_lib", "translate_cl_options", "translate_registration", diff --git a/tensorflow/compiler/mlir/tensorflow/translate/BUILD b/tensorflow/compiler/mlir/tensorflow/translate/BUILD index 6e3a8377b3d81e..39363980e7e570 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/translate/BUILD @@ -34,7 +34,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:mangling_util", - "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/compiler/mlir/tensorflow:saved_model_import_options", "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:translate_utils", @@ -175,8 +175,8 @@ cc_library( ) cc_library( - name = "mlir_import_options", - hdrs = ["mlir_import_options.h"], + name = "saved_model_import_options", + hdrs = ["saved_model_import_options.h"], visibility = ["//visibility:public"], ) @@ -194,7 +194,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:import_utils", "//tensorflow/compiler/mlir/tensorflow:mangling_util", - "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/compiler/mlir/tensorflow:saved_model_import_options", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 3e72550a88749a..7b5b940eac24f8 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -84,8 +84,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h" #include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" @@ -2791,7 +2791,7 @@ class SavedModelObjectGraphImporter : public ImporterBase { // Module. static absl::StatusOr> Convert( SavedModelV2Bundle* saved_model, absl::Span exported_names, - mlir::MLIRContext* context, MLIRImportOptions options); + mlir::MLIRContext* context, SavedModelImportOptions options); private: explicit SavedModelObjectGraphImporter( @@ -3345,7 +3345,7 @@ Status CreateSavedModelIR( const ObjectNames& object_names, mlir::ModuleOp module, const SavedObjectGraph& object_graph, const std::unordered_map& tf_name_to_mlir_name, - SavedModelV2Bundle* saved_model, MLIRImportOptions import_options) { + SavedModelV2Bundle* saved_model, SavedModelImportOptions import_options) { mlir::OpBuilder builder(module.getBodyRegion()); mlir::SymbolTable symbol_table(module); @@ -3562,7 +3562,7 @@ absl::StatusOr> SavedModelObjectGraphImporter::Convert(SavedModelV2Bundle* saved_model, absl::Span exported_names, mlir::MLIRContext* context, - MLIRImportOptions import_options) { + SavedModelImportOptions import_options) { LoadImporterDialects(*context); GraphDebugInfo dummy_debug_info; const GraphDebugInfo& debug_info = @@ -3641,7 +3641,7 @@ SavedModelObjectGraphImporter::Convert(SavedModelV2Bundle* saved_model, class SimpleSavedModelMLIRImportInput : public SavedModelMLIRImportInput { public: static absl::StatusOr Create( - const MLIRImportOptions& import_options, + const SavedModelImportOptions& import_options, const MetaGraphDef* meta_graph_def, const GraphDebugInfo& debug_info) { DCHECK(meta_graph_def); GraphDef graph_def(meta_graph_def->graph_def()); @@ -4183,7 +4183,7 @@ class SavedModelSignatureDefImporter { static absl::StatusOr> Convert( const SavedModelBundle& bundle, std::optional> exported_names, - mlir::MLIRContext* context, tensorflow::MLIRImportOptions options) { + mlir::MLIRContext* context, tensorflow::SavedModelImportOptions options) { // debug_info might not be loaded with loader_lite. GraphDebugInfo debug_info; if (bundle.debug_info != nullptr) debug_info = *bundle.debug_info; @@ -4348,14 +4348,14 @@ absl::StatusOr> ConvertFunctionToMlir( absl::StatusOr> ConvertSavedModelToMlir( SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, - absl::Span exported_names, MLIRImportOptions options) { + absl::Span exported_names, SavedModelImportOptions options) { return SavedModelObjectGraphImporter::Convert(saved_model, exported_names, context, options); } absl::StatusOr> ConvertSavedModelV1ToMlir( const SavedModelBundle& saved_model, absl::Span exported_names, - mlir::MLIRContext* context, MLIRImportOptions options) { + mlir::MLIRContext* context, SavedModelImportOptions options) { std::optional> optional_exported_names; // TODO(b/187062560): Change ConvertSavedModelV1ToMlir() to take an optional // `exported_names` so that it can be configured to import only restore/init @@ -4368,7 +4368,7 @@ absl::StatusOr> ConvertSavedModelV1ToMlir( absl::StatusOr> ConvertSavedModelV1ToMlirLite( const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info, std::optional> exported_names, - mlir::MLIRContext* context, MLIRImportOptions options) { + mlir::MLIRContext* context, SavedModelImportOptions options) { TF_ASSIGN_OR_RETURN(auto input, SimpleSavedModelMLIRImportInput::Create( options, &meta_graph_def, debug_info)); return ConvertSavedModelV1ToMlirLite( diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index bca1f7f80af9e8..07d3ed0caa1b03 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -26,8 +26,8 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/cc/saved_model/loader.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" @@ -61,13 +61,14 @@ absl::StatusOr> ConvertFunctionToMlir( // with tf_executor dialect. absl::StatusOr> ConvertSavedModelToMlir( SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, - absl::Span exported_names, MLIRImportOptions options = {}); + absl::Span exported_names, + SavedModelImportOptions options = {}); // Given a V1 SavedModel, returns a MLIR module containing the functions, // expressed with tf_executor dialect. absl::StatusOr> ConvertSavedModelV1ToMlir( const SavedModelBundle& saved_model, absl::Span exported_names, - mlir::MLIRContext* context, MLIRImportOptions options = {}); + mlir::MLIRContext* context, SavedModelImportOptions options = {}); // Given a V1 SavedModel, returns a MLIR module containing the functions, // expressed with tf_executor dialect. It does not require a session to be @@ -82,7 +83,7 @@ absl::StatusOr> ConvertSavedModelV1ToMlir( absl::StatusOr> ConvertSavedModelV1ToMlirLite( const MetaGraphDef& meta_graph_def, const GraphDebugInfo& debug_info, std::optional> exported_names, - mlir::MLIRContext* context, MLIRImportOptions options); + mlir::MLIRContext* context, SavedModelImportOptions options); // SavedModelMLIRImportInput is an adapter class for users to inject custom // graph transformation logic on Tensorflow graphs before importing to MLIR. It diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h b/tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h similarity index 80% rename from tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h rename to tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h index 44262d0bd08d86..68ffd86515f893 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h @@ -13,16 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_IMPORT_OPTIONS_H_ -#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_IMPORT_OPTIONS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_SAVED_MODEL_IMPORT_OPTIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_SAVED_MODEL_IMPORT_OPTIONS_H_ namespace tensorflow { -// TODO(jpienaar): This file and class are confusingly named. This seems to be -// a SavedModel only import options file that exposes a subset of the -// GraphImportConfig options, but the naming would make one think it is more -// general. -struct MLIRImportOptions { +struct SavedModelImportOptions { // If true, functionalize the input graph before importing it into MLIR. bool upgrade_legacy = false; @@ -53,4 +49,4 @@ struct MLIRImportOptions { } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_MLIR_IMPORT_OPTIONS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_SAVED_MODEL_IMPORT_OPTIONS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index 6eaa15e37d45e1..29a87995c493ae 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" @@ -165,7 +166,7 @@ SavedModelObjectGraphToMlirImport(absl::string_view saved_model_dir, return load_status; } - MLIRImportOptions options; + SavedModelImportOptions options; options.add_default_attributes = true; options.unconditionally_use_set_output_shapes = unconditionally_use_set_output_shapes; @@ -183,7 +184,7 @@ SavedModelSignatureDefsToMlirImport( absl::string_view saved_model_dir, const std::unordered_set& tags, absl::Span exported_names, mlir::MLIRContext* context, - MLIRImportOptions options, + SavedModelImportOptions options, std::unique_ptr* saved_model_bundle) { // Create local bundle if no one is provided to use. std::unique_ptr bundle; @@ -220,7 +221,7 @@ SavedModelSignatureDefsToMlirImportLite( absl::string_view saved_model_dir, const std::unordered_set& tags, absl::Span exported_names, mlir::MLIRContext* context, - MLIRImportOptions options) { + SavedModelImportOptions options) { MetaGraphDef meta_graph_def; auto status = ReadMetaGraphDefFromSavedModel(saved_model_dir, tags, &meta_graph_def); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h index cd86b27e13550c..c4bdef05dfd627 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h @@ -28,7 +28,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/cc/saved_model/loader.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h" namespace tensorflow { @@ -116,7 +116,7 @@ SavedModelSignatureDefsToMlirImport( absl::string_view saved_model_dir, const std::unordered_set& tags, absl::Span exported_names, mlir::MLIRContext* context, - MLIRImportOptions options, + SavedModelImportOptions options, std::unique_ptr* saved_model_bundle = nullptr); @@ -129,7 +129,7 @@ SavedModelSignatureDefsToMlirImportLite( absl::string_view saved_model_dir, const std::unordered_set& tags, absl::Span exported_names, mlir::MLIRContext* context, - MLIRImportOptions options); + SavedModelImportOptions options); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index 00ae360f1e697e..3634089145492d 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/Support/ToolUtilities.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/saved_model_import_options.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" #include "tensorflow/core/platform/init_main.h" @@ -131,7 +132,7 @@ int main(int argc, char** argv) { module_or.value()->print(output->os()); } else if (import_saved_model_signature_defs) { mlir::MLIRContext context; - tensorflow::MLIRImportOptions import_options; + tensorflow::SavedModelImportOptions import_options; import_options.upgrade_legacy = upgrade_legacy; auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport( input_filename, tags, exported_names, &context, import_options); @@ -140,7 +141,7 @@ int main(int argc, char** argv) { module_or.value()->print(output->os()); } else if (import_saved_model_signature_defs_lite) { mlir::MLIRContext context; - tensorflow::MLIRImportOptions import_options; + tensorflow::SavedModelImportOptions import_options; import_options.upgrade_legacy = upgrade_legacy; auto module_or = tensorflow::SavedModelSignatureDefsToMlirImportLite( input_filename, tags, exported_names, &context, import_options);