From d1a5c9af980cdfb4da1abb5abf52590103f5b80a Mon Sep 17 00:00:00 2001
From: Arian Arfaian
Date: Thu, 24 Aug 2023 11:15:32 -0700
Subject: [PATCH] Always return error status if serialization to FlatBuffer failed.

An error is not emitted to the DiagnosticHandler in all cases in which
serialization can fail. This change bases failure on the result of the
`MlirToFlatBufferTranslateFunction` method call rather than the value of the
status handler's status.

PiperOrigin-RevId: 559811886
---
 tensorflow/compiler/mlir/lite/BUILD                |  2 ++
 .../compiler/mlir/lite/tf_tfl_translate.cc         |  5 ++++-
 .../compiler/mlir/lite/tf_to_tfl_flatbuffer.cc     | 16 ++++++++++++++--
 3 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 301e904511ba0d..90eb687c4de63e 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -1334,6 +1334,8 @@ cc_library(
         "//tensorflow/lite/tools/optimize:quantize_weights",
         "//tensorflow/lite/tools/optimize:reduced_precision_support",
         "//tensorflow/tsl/platform:statusor",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:FuncDialect",
diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
index ffd9dbca0fac29..70793925485481 100644
--- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
+++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
@@ -323,7 +323,10 @@ int main(int argc, char **argv) {
   auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
       module.value().get(), output_mlir, toco_flags, pass_config, tags,
       /*saved_model_dir=*/"", session, &result, serialize_stablehlo_ops);
-  if (!status.ok()) return kTrFailure;
+  if (!status.ok()) {
+    llvm::errs() << status.message() << '\n';
+    return kTrFailure;
+  }
 
   std::string error_msg;
   auto output = mlir::openOutputFile(output_file_name, &error_msg);
diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
index 3b93694f3286f0..59e5bbf724f592 100644
--- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
@@ -23,6 +23,8 @@ limitations under the License.
 #include
 #include
 
+#include "absl/status/status.h"
+#include "absl/strings/str_cat.h"
 #include "absl/types/span.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
@@ -296,7 +298,12 @@ Status ConvertTFExecutorToStablehloFlatbuffer(
   options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
   options.metadata[tflite::kModelUseStablehloTensorKey] = "true";
   if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
-    return statusHandler.ConsumeStatus();
+    auto s = statusHandler.ConsumeStatus();
+    std::string message = "Could not translate MLIR to FlatBuffer.";
+    if (!s.ok()) {
+      absl::StrAppend(&message, " ", s.ToString());
+    }
+    return absl::UnknownError(message);
   }
 
   return OkStatus();
@@ -425,7 +432,12 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
   }
   if (!tflite::MlirToFlatBufferTranslateFunction(
           module, options, &translated_result, serialize_stablehlo_ops)) {
-    return statusHandler.ConsumeStatus();
+    auto s = statusHandler.ConsumeStatus();
+    std::string message = "Could not translate MLIR to FlatBuffer.";
+    if (!s.ok()) {
+      absl::StrAppend(&message, " ", s.ToString());
+    }
+    return absl::UnknownError(message);
   }
 
   // TODO(b/176267167): Quantize flex fallback in the MLIR pipeline