[go: nahoru, domu]

Skip to content

Commit

Permalink
Automatically detect whether the input model contains StableHLO.
Browse files Browse the repository at this point in the history
Removes the need for special flags to handle models generated from JAX with
`native_serialization=True`. Also removes the
`_experimental_enable_hlo_to_tf_conversion` option from the Converter API.

PiperOrigin-RevId: 547033526
  • Loading branch information
arfaian authored and tensorflower-gardener committed Jul 11, 2023
1 parent a622eb3 commit 36fe2f9
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 17 deletions.
3 changes: 1 addition & 2 deletions tensorflow/compiler/mlir/lite/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tf_tfl_passes",
"//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
"//tensorflow/compiler/mlir/lite/metrics:error_collector_inst",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
"//tensorflow/compiler/mlir/tensorflow:import_model",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ Status ConvertGraphDefToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
pass_config.guarantee_all_funcs_one_use =
toco_flags.guarantee_all_funcs_one_use();
pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo();
pass_config.enable_hlo_to_tf_conversion =
toco_flags.enable_hlo_to_tf_conversion();

return internal::ConvertMLIRToTFLiteFlatBuffer(
model_flags, toco_flags, std::move(module), pass_config,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,6 @@ Status ConvertSavedModelToTFLiteFlatBuffer(const toco::ModelFlags& model_flags,
pass_config.guarantee_all_funcs_one_use =
toco_flags.guarantee_all_funcs_one_use();
pass_config.enable_stablehlo_conversion = toco_flags.convert_to_stablehlo();
pass_config.enable_hlo_to_tf_conversion =
toco_flags.enable_hlo_to_tf_conversion();
pass_config.legalize_custom_tensor_list_ops =
toco_flags.legalize_custom_tensor_list_ops();

Expand Down
13 changes: 13 additions & 0 deletions tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
#include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
Expand Down Expand Up @@ -346,6 +347,18 @@ Status ConvertMLIRToTFLiteFlatBuffer(

mlir::TFL::PassConfig pass_config_copy = pass_config;
pass_config_copy.outline_tf_while = true;

// Checks whether the model contains an `XlaCallModuleOp` operation, which
// is a wrapper around StableHLO. If at least one is found, HLO to TF
// conversion is enabled on the local copy of the pass config.
// `enable_hlo_to_tf_conversion` is mutually exclusive with
// `enable_stablehlo_conversion`; the latter takes precedence.
// TODO(b/290109282): explore removing the enable_hlo_to_tf_conversion flag
// entirely, such that the added passes are no-ops in the non-shlo case.
module->walk([&](mlir::TF::XlaCallModuleOp xla_call_module_op) {
  pass_config_copy.enable_hlo_to_tf_conversion = true;
  // A single match is sufficient. The interrupt result must be *returned*
  // from the callback for the walk to stop; merely constructing it (as a
  // discarded expression in a void callback) lets the traversal continue
  // over the rest of the module.
  return mlir::WalkResult::interrupt();
});

auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, toco_flags, pass_config_copy,
saved_model_tags, model_flags.saved_model_dir(), session, result);
Expand Down
6 changes: 0 additions & 6 deletions tensorflow/lite/python/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,6 @@ def build_conversion_flags(
enable_mlir_variable_quantization=False,
disable_fuse_mul_and_fc=False,
quantization_options: Optional[quant_opts_pb2.QuantizationOptions] = None,
enable_hlo_to_tf_conversion=False,
mlir_dump_dir=None,
mlir_dump_pass_regex=None,
mlir_dump_func_regex=None,
Expand Down Expand Up @@ -686,9 +685,6 @@ def build_conversion_flags(
a custom method, and allows finer, modular control. This option will
override any other existing quantization flags. We plan on gradually
migrating all quantization-related specs into this option.
enable_hlo_to_tf_conversion: Enable HLO to TF conversion in the Converter.
Set this to False by default as this may increase the conversion time if
set otherwise.
mlir_dump_dir: A string specifying the target directory to output MLIR dumps
produced during conversion. If populated, enables MLIR dumps.
mlir_dump_pass_regex: A string containing a regular expression for filtering
Expand Down Expand Up @@ -797,8 +793,6 @@ def build_conversion_flags(
if quantization_options:
conversion_flags.quantization_options.CopyFrom(quantization_options)

conversion_flags.enable_hlo_to_tf_conversion = enable_hlo_to_tf_conversion

# Transfer debug options. Check for existence before populating in order to
# leverage defaults specified in proto definition.
if mlir_dump_dir is not None:
Expand Down
4 changes: 0 additions & 4 deletions tensorflow/lite/python/lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,6 @@ def __init__(self):
self._experimental_enable_dynamic_update_slice = False
self._experimental_preserve_assert_op = False
self._experimental_guarantee_all_funcs_one_use = False
self._experimental_enable_hlo_to_tf_conversion = False

# When the value is true, the MLIR quantizer triggers dynamic range
# quantization in MLIR instead of the old quantizer. Used only if
Expand Down Expand Up @@ -790,9 +789,6 @@ def _get_base_converter_args(self):
"allow_all_select_tf_ops": self._experimental_allow_all_select_tf_ops,
"disable_fuse_mul_and_fc": self._experimental_disable_fuse_mul_and_fc,
"quantization_options": self._experimental_quantization_options,
"enable_hlo_to_tf_conversion": (
self._experimental_enable_hlo_to_tf_conversion
),
"mlir_dump_dir": self.mlir_dump_dir,
"mlir_dump_pass_regex": self.mlir_dump_pass_regex,
"mlir_dump_func_regex": self.mlir_dump_func_regex,
Expand Down
3 changes: 2 additions & 1 deletion tensorflow/lite/toco/toco_flags.proto
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ message TocoFlags {

// Flag to enable hlo to tf conversion.
// This is useful to exercise StableHLO -> HLO -> TF -> TFLite path.
optional bool enable_hlo_to_tf_conversion = 55 [default = false];
optional bool enable_hlo_to_tf_conversion = 55
[default = false, deprecated = true];

// Additional parameters for controlling debug facilities.
optional tensorflow.converter.DebugOptions debug_options = 56;
Expand Down

0 comments on commit 36fe2f9

Please sign in to comment.