[go: nahoru, domu]

Skip to content

Commit

Permalink
Always return an error status if serialization to FlatBuffer fails.
Browse files Browse the repository at this point in the history
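Not every serialization failure emits an error to the DiagnosticHandler, so the
status handler's status can still be OK after serialization has failed. This
change decides failure from the return value of the
`MlirToFlatBufferTranslateFunction` call rather than from the status handler's
status. On failure, an `absl::UnknownError` is returned whose message is "Could not
translate MLIR to FlatBuffer.", with the handler's status appended when it is not OK.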

PiperOrigin-RevId: 559811886
  • Loading branch information
arfaian authored and tensorflower-gardener committed Aug 24, 2023
1 parent 4b34b37 commit d1a5c9a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,8 @@ cc_library(
"//tensorflow/lite/tools/optimize:quantize_weights",
"//tensorflow/lite/tools/optimize:reduced_precision_support",
"//tensorflow/tsl/platform:statusor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
Expand Down
5 changes: 4 additions & 1 deletion tensorflow/compiler/mlir/lite/tf_tfl_translate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,10 @@ int main(int argc, char **argv) {
auto status = tensorflow::ConvertTFExecutorToTFLOrFlatbuffer(
module.value().get(), output_mlir, toco_flags, pass_config, tags,
/*saved_model_dir=*/"", session, &result, serialize_stablehlo_ops);
if (!status.ok()) return kTrFailure;
if (!status.ok()) {
llvm::errs() << status.message() << '\n';
return kTrFailure;
}

std::string error_msg;
auto output = mlir::openOutputFile(output_file_name, &error_msg);
Expand Down
16 changes: 14 additions & 2 deletions tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -296,7 +298,12 @@ Status ConvertTFExecutorToStablehloFlatbuffer(
options.op_or_arg_name_mapper = &op_or_arg_name_mapper;
options.metadata[tflite::kModelUseStablehloTensorKey] = "true";
if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) {
return statusHandler.ConsumeStatus();
auto s = statusHandler.ConsumeStatus();
std::string message = "Could not translate MLIR to FlatBuffer.";
if (!s.ok()) {
absl::StrAppend(&message, " ", s.ToString());
}
return absl::UnknownError(message);
}

return OkStatus();
Expand Down Expand Up @@ -425,7 +432,12 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
}
if (!tflite::MlirToFlatBufferTranslateFunction(
module, options, &translated_result, serialize_stablehlo_ops)) {
return statusHandler.ConsumeStatus();
auto s = statusHandler.ConsumeStatus();
std::string message = "Could not translate MLIR to FlatBuffer.";
if (!s.ok()) {
absl::StrAppend(&message, " ", s.ToString());
}
return absl::UnknownError(message);
}

// TODO(b/176267167): Quantize flex fallback in the MLIR pipeline
Expand Down

0 comments on commit d1a5c9a

Please sign in to comment.