Clean up DiagnosticHandler usage.
In some cases, status messages were getting unintentionally dropped. These
changes reduce some of the chaos and allow proper propagation of all captured
diagnostic messages.

PiperOrigin-RevId: 601174716
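
The recurring pattern this commit introduces is easiest to see in isolation: on failure, drain the scoped diagnostic handler's captured status first, then fold its message into the fixed error string rather than returning the fixed string alone. Below is a minimal, self-contained sketch of that pattern; RunStep and CapturedStatus are hypothetical stand-ins for the pass run and for StatusScopedDiagnosticHandler::ConsumeStatus(), not TensorFlow APIs.

    #include <string>

    #include "absl/status/status.h"
    #include "absl/strings/str_cat.h"

    // Hypothetical stand-ins: a failing conversion step, and the status a
    // mlir::StatusScopedDiagnosticHandler would hand back from ConsumeStatus().
    bool RunStep() { return false; }
    absl::Status CapturedStatus() {
      return absl::InternalError("'tf.SomeOp' op is not legalizable");
    }

    absl::Status Convert() {
      if (!RunStep()) {
        // Drain the captured diagnostics first, then append their text to the
        // fixed message so nothing the handler saw is dropped.
        absl::Status s = CapturedStatus();
        std::string message = "Could not translate MLIR to FlatBuffer.";
        if (!s.ok()) absl::StrAppend(&message, " ", s.ToString());
        return absl::UnknownError(message);
      }
      return absl::OkStatus();
    }

With this shape, a caller sees both the high-level failure ("Could not translate...") and the specific diagnostic that caused it, instead of the former alone.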
arfaian authored and tensorflower-gardener committed Jan 24, 2024
1 parent 6172674 commit accc669
Showing 1 changed file with 86 additions and 47 deletions.

tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc
@@ -130,7 +130,7 @@ Status RegisterExtraTfOpDefs(absl::Span<const std::string> extra_tf_opdefs) {
     if (!tensorflow::protobuf::TextFormat::ParseFromString(tf_opdefs_string,
                                                            &opdef)) {
       LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string;
-      return absl::InvalidArgumentError("fail to parse extra OpDef");
+      return errors::InvalidArgument("fail to parse extra OpDef");
     }
     // Register extra opdefs.
     // TODO(b/133770952): Support shape functions.
@@ -156,8 +156,8 @@ StatusOr<OwningOpRef<ModuleOp>> LoadFromGraphdefOrMlirSource(
   std::string error_message;
   auto file = mlir::openInputFile(input_filename, &error_message);
   if (!file) {
-    return absl::InvalidArgumentError(
-        absl::StrCat("Failed to open input file: ", error_message));
+    llvm::errs() << error_message << "\n";
+    return errors::InvalidArgument("fail to open input file");
   }
 
   if (input_mlir) {
@@ -213,16 +213,15 @@ Status ApplyDynamicRangeQuantizationFromOldQuantizer(
       quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16;
       break;
     default:
-      return absl::InvalidArgumentError("Quantized type not supported");
+      return errors::InvalidArgument("Quantized type not supported");
       break;
   }
 
   bool use_updated_hybrid_scheme = !quant_specs.disable_per_channel;
   if (::tflite::optimize::QuantizeWeights(
           &q_builder, input_model, quantized_type, use_updated_hybrid_scheme,
           ::tflite::optimize::QuantizerType::OLD_QUANTIZER) != kTfLiteOk) {
-    return absl::InvalidArgumentError(
-        "Quantize weights transformation failed.");
+    return errors::InvalidArgument("Quantize weights transformation failed.");
   }
   const uint8_t* q_buffer = q_builder.GetBufferPointer();
   *result =
@@ -233,7 +232,7 @@ Status ApplyDynamicRangeQuantizationFromOldQuantizer(
 
 Status ConvertTFExecutorToStablehloFlatbuffer(
     mlir::PassManager& pass_manager, mlir::ModuleOp module, bool export_to_mlir,
-    mlir::StatusScopedDiagnosticHandler& status_handler,
+    mlir::StatusScopedDiagnosticHandler& statusHandler,
     const toco::TocoFlags& toco_flags, const mlir::TFL::PassConfig& pass_config,
     std::optional<tensorflow::Session*> session, std::string* result,
     const std::unordered_set<std::string>& saved_model_tags) {
@@ -245,8 +244,7 @@ Status ConvertTFExecutorToStablehloFlatbuffer(
   const auto status = tensorflow::quantization::PreprocessAndFreezeGraph(
       module, module.getContext(), session);
   if (!status.ok()) {
-    return status_handler.Combine(
-        absl::InternalError("Failed to preprocess & freeze TF graph."));
+    return errors::Aborted("Failed to preprocess & freeze TF graph");
   }
 
   // TODO(b/264218457): Refactor the component below once StableHLO Quantizer
@@ -269,7 +267,7 @@ Status ConvertTFExecutorToStablehloFlatbuffer(
         pass_manager, quantization_options);
   }
   if (failed(pass_manager.run(module))) {
-    return status_handler.ConsumeStatus();
+    return statusHandler.ConsumeStatus();
   }
 }
 
@@ -285,13 +283,13 @@ Status ConvertTFExecutorToStablehloFlatbuffer(
         pass_manager, toco_flags.quantization_options());
   }
   if (failed(pass_manager.run(module))) {
-    return status_handler.ConsumeStatus();
+    return statusHandler.ConsumeStatus();
   }
 
   if (export_to_mlir) {
     llvm::raw_string_ostream os(*result);
     module.print(os);
-    return status_handler.ConsumeStatus();
+    return statusHandler.ConsumeStatus();
   }
 
   // Write MLIR Stablehlo dialect into FlatBuffer
@@ -303,8 +301,12 @@ Status ConvertTFExecutorToStablehloFlatbuffer(
   options.metadata[tflite::kModelUseStablehloTensorKey] = "true";
   if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result,
                                                  true)) {
-    return status_handler.Combine(
-        absl::InternalError("Could not translate MLIR to FlatBuffer."));
+    auto s = statusHandler.ConsumeStatus();
+    std::string message = "Could not translate MLIR to FlatBuffer.";
+    if (!s.ok()) {
+      absl::StrAppend(&message, " ", s.ToString());
+    }
+    return absl::UnknownError(message);
   }
 
   return OkStatus();
@@ -324,57 +326,92 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
   mlir::func::registerAllExtensions(registry);
   module.getContext()->appendDialectRegistry(registry);
 
-  mlir::StatusScopedDiagnosticHandler status_handler(module.getContext(),
-                                                     /*propagate=*/true);
+  // Register a warning handler that only logs to stdout.
+  mlir::ScopedDiagnosticHandler s(
+      module.getContext(), [](mlir::Diagnostic& diag) {
+        if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning) {
+          for (auto& note : diag.getNotes()) {
+            std::cout << note.str() << "\n";
+            LOG(WARNING) << note.str() << "\n";
+          }
+        }
+        return mlir::failure();
+      });
+
+  mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
+                                                    /*propagate=*/true);
+
+  if (failed(IsValidGraph(module))) {
+    return statusHandler.ConsumeStatus();
+  }
+
   mlir::PassManager pass_manager(module.getContext());
   mlir::registerPassManagerCLOptions();
   if (mlir::failed(mlir::applyPassManagerCLOptions(pass_manager))) {
-    return absl::InternalError("Failed to apply MLIR pass manager CL options.");
+    return absl::UnknownError("failed to apply MLIR pass manager CL options");
   }
   pass_manager.addInstrumentation(
       std::make_unique<mlir::TFL::ErrorCollectorInstrumentation>(
           pass_manager.getContext()));
   InitPassManager(pass_manager, toco_flags.debug_options());
 
-  if (failed(IsValidGraph(module))) {
-    return status_handler.ConsumeStatus();
-  }
-
   if (pass_config.enable_stablehlo_conversion) {
     // return to avoid adding TFL converter path
     return ConvertTFExecutorToStablehloFlatbuffer(
-        pass_manager, module, export_to_mlir, status_handler, toco_flags,
+        pass_manager, module, export_to_mlir, statusHandler, toco_flags,
         pass_config, session, result, saved_model_tags);
   }
 
   tensorflow::AddPreVariableFreezingTFToTFLConversionPasses(pass_config,
                                                             &pass_manager);
   if (failed(pass_manager.run(module))) {
-    return status_handler.ConsumeStatus();
+    return statusHandler.ConsumeStatus();
   }
 
   // Freeze variables if a session is provided.
-  if (session.has_value() &&
-      failed(mlir::tf_saved_model::FreezeVariables(module, session.value()))) {
-    return status_handler.Combine(absl::InvalidArgumentError(
-        "Variable constant folding is failed. Please consider using "
-        "enabling `experimental_enable_resource_variables` flag in the "
-        "TFLite converter object. For example, "
-        "converter.experimental_enable_resource_variables = True"));
+  if (session.has_value()) {
+    mlir::TFL::ErrorCollectorInstrumentation collector(module.getContext());
+    if (failed(
+            mlir::tf_saved_model::FreezeVariables(module, session.value()))) {
+      auto status = statusHandler.ConsumeStatus();
+      mlir::TFL::ErrorCollector* collector =
+          mlir::TFL::ErrorCollector::GetErrorCollector();
+      if (!collector->CollectedErrors().empty()) {
+        // LINT.IfChange
+        return errors::InvalidArgument(
+            "Variable constant folding is failed. Please consider using "
+            "enabling `experimental_enable_resource_variables` flag in the "
+            "TFLite converter object. For example, "
+            "converter.experimental_enable_resource_variables = True");
+        // LINT.ThenChange(//tensorflow/lite/python/lite_v2_test.py)
+      }
+      return status;
+    }
   }
 
   pass_manager.clear();
 
   tensorflow::AddPostVariableFreezingTFToTFLConversionPasses(
       saved_model_dir, toco_flags, pass_config, &pass_manager);
   if (failed(pass_manager.run(module))) {
-    return status_handler.Combine(absl::InvalidArgumentError(
-        "Variable constant folding is failed. Please consider using "
-        "enabling `experimental_enable_resource_variables` flag in the "
-        "TFLite converter object. For example, "
-        "converter.experimental_enable_resource_variables = True"));
+    auto status = statusHandler.ConsumeStatus();
+    mlir::TFL::ErrorCollector* collector =
+        mlir::TFL::ErrorCollector::GetErrorCollector();
+    for (const auto& error_data : collector->CollectedErrors()) {
+      if (error_data.subcomponent() == "FreezeGlobalTensorsPass") {
+        // LINT.IfChange
+        return errors::InvalidArgument(
+            "Variable constant folding is failed. Please consider using "
+            "enabling `experimental_enable_resource_variables` flag in the "
+            "TFLite converter object. For example, "
+            "converter.experimental_enable_resource_variables = True");
+        // LINT.ThenChange(//tensorflow/lite/python/lite_v2_test.py)
+      }
+    }
+    return status;
   }
 
   if (failed(GraphContainsStatefulPartitionedOp(module))) {
-    return status_handler.ConsumeStatus();
+    return statusHandler.ConsumeStatus();
  }
 
   if (export_to_mlir) {
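
One MLIR convention the new warning handler relies on, worth spelling out: diagnostic handlers run newest-first, and a handler that returns mlir::failure() marks the diagnostic as unconsumed, so it keeps propagating to the other registered handlers. That is how the warning logger above coexists with the StatusScopedDiagnosticHandler without swallowing anything. A standalone sketch of that convention (not from the commit; the emitted messages are made up):

    #include "llvm/Support/raw_ostream.h"
    #include "mlir/IR/Diagnostics.h"
    #include "mlir/IR/Location.h"
    #include "mlir/IR/MLIRContext.h"

    int main() {
      mlir::MLIRContext context;

      // Registered first, so it runs last: sees whatever `inner` declines.
      mlir::ScopedDiagnosticHandler outer(&context, [](mlir::Diagnostic& diag) {
        llvm::errs() << "captured: " << diag.str() << "\n";
        return mlir::success();  // consumed; propagation stops here
      });

      // Registered second, so it runs first: peeks at warnings, then declines
      // the diagnostic so `outer` still receives it.
      mlir::ScopedDiagnosticHandler inner(&context, [](mlir::Diagnostic& diag) {
        if (diag.getSeverity() == mlir::DiagnosticSeverity::Warning)
          llvm::errs() << "warning seen\n";
        return mlir::failure();  // not consumed; falls through
      });

      mlir::emitWarning(mlir::UnknownLoc::get(&context), "made-up warning");
      mlir::emitError(mlir::UnknownLoc::get(&context), "made-up error");
      return 0;
    }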
@@ -383,12 +420,12 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
     pass_manager.addPass(mlir::odml::createPrintOpStatsPass(
         mlir::odml::GetAcceptedTFLiteDialects()));
     if (failed(pass_manager.run(module))) {
-      return status_handler.ConsumeStatus();
+      return statusHandler.ConsumeStatus();
     }
 
     llvm::raw_string_ostream os(*result);
     module.print(os);
-    return status_handler.ConsumeStatus();
+    return statusHandler.ConsumeStatus();
   }
 
   // Write MLIR TFLite dialect into FlatBuffer
@@ -406,8 +443,12 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
   }
   if (!tflite::MlirToFlatBufferTranslateFunction(
           module, options, &translated_result, serialize_stablehlo_ops)) {
-    return status_handler.Combine(
-        absl::InternalError("Could not translate MLIR to FlatBuffer."));
+    auto s = statusHandler.ConsumeStatus();
+    std::string message = "Could not translate MLIR to FlatBuffer.";
+    if (!s.ok()) {
+      absl::StrAppend(&message, " ", s.ToString());
+    }
+    return absl::UnknownError(message);
   }
 
   // TODO(b/176267167): Quantize flex fallback in the MLIR pipeline
@@ -419,16 +460,13 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
     // statement.
     auto status = ApplyDynamicRangeQuantizationFromOldQuantizer(
         quant_specs, translated_result, result);
-    if (!status.ok()) {
-      return status_handler.Combine(status);
-    }
+    if (!status.ok()) return status;
   } else {
     *result = translated_result;
   }
 
   if (mlir::failed(module.verifyInvariants())) {
-    return status_handler.Combine(
-        absl::InternalError("Final module is invalid."));
+    return tensorflow::errors::Unknown("Final module is invalid");
   }
   return OkStatus();
 }
@@ -462,7 +500,8 @@ StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> ImportSavedModel(
     if (!module_or.status().ok()) return module_or.status();
     return std::move(module_or).value();
   } else {
-    return absl::InvalidArgumentError("Should be either saved model v1 or v2.");
+    return tensorflow::errors::InvalidArgument(
+        "Should be either saved model v1 or v2");
   }
 }
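
Taken together, the freeze-related paths are where messages were previously being dropped: any failure in the post-variable-freezing pipeline used to return the fixed "Variable constant folding is failed" hint, even when an unrelated pass had failed. The commit narrows that hint to failures actually attributed to FreezeGlobalTensorsPass and propagates the drained status otherwise. A condensed, hypothetical rendering of that decision follows; ErrorData and PickStatus are stand-ins, not the real TFLite error-collector types.

    #include <string>
    #include <vector>

    #include "absl/status/status.h"

    // Stand-in for the collected per-pass error records consulted in the diff
    // (the real code reads mlir::TFL::ErrorCollector::CollectedErrors()).
    struct ErrorData {
      std::string subcomponent;
    };

    // After a failed pass run: surface the targeted, user-actionable hint only
    // when the variable-freezing pass is implicated; otherwise return the
    // status drained from the diagnostic handler, message intact.
    absl::Status PickStatus(const std::vector<ErrorData>& collected_errors,
                            absl::Status drained_status) {
      for (const ErrorData& error : collected_errors) {
        if (error.subcomponent == "FreezeGlobalTensorsPass") {
          return absl::InvalidArgumentError(
              "Variable constant folding is failed. Please consider using "
              "enabling `experimental_enable_resource_variables` flag in the "
              "TFLite converter object.");
        }
      }
      return drained_status;
    }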
