Commit d5b1215

Clone the module into a new context to erase the unused tensors from memory.

Ideally, unused tensors cleaned up during DCE should not consume any memory, but they currently do. This is due to existing design flaws in MLIRContext, which have been addressed upstream with the introduction of DenseResourceElementsAttr in MLIR.

This is a temporary workaround until we are able to use the DenseResourceElementsAttr to store the model checkpoint weights.
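For illustration only, here is a minimal sketch of the round-trip such a clone boils down to, given the BytecodeWriter and Parser dependencies added below. The helper name is an assumption, not part of this commit: serialize the module to MLIR bytecode and re-parse it in a fresh MLIRContext, after which the original context, and every tensor attribute interned in it, can be destroyed.

// Sketch only: round-trip `module` through bytecode into `fresh_context`.
// All dialects used by the module must be registered there beforehand.
#include <string>

#include "llvm/Support/raw_ostream.h"
#include "mlir/Bytecode/BytecodeWriter.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project

mlir::OwningOpRef<mlir::ModuleOp> CloneIntoFreshContext(
    mlir::ModuleOp module, mlir::MLIRContext& fresh_context) {
  std::string bytecode;
  llvm::raw_string_ostream os(bytecode);
  if (mlir::failed(mlir::writeBytecodeToFile(module, os))) return nullptr;
  os.flush();
  // The parser detects the bytecode magic and re-creates the module, now
  // owned by `fresh_context` instead of the original context.
  return mlir::parseSourceString<mlir::ModuleOp>(bytecode, &fresh_context);
}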

PiperOrigin-RevId: 645520361
vamsimanchala authored and tensorflower-gardener committed Jun 26, 2024
1 parent 3abf81f commit d5b1215
Showing 14 changed files with 456 additions and 173 deletions.
12 changes: 12 additions & 0 deletions tensorflow/compiler/mlir/lite/BUILD
@@ -1415,6 +1415,7 @@ cc_library(
],
deps = [
":common",
":const_tensor_utils",
":flatbuffer_translate_lib",
":tensorflow_lite",
":tf_tfl_passes",
@@ -1449,6 +1450,7 @@ cc_library(
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/ir/types:Dialect",
"//tensorflow/lite/python/metrics:converter_error_data_proto_cc",
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/tools/optimize:quantize_weights",
@@ -1460,14 +1462,22 @@
"@com_google_absl//absl/types:span",
"@flatbuffers//:runtime_cc",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:BytecodeWriter",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FuncExtensions",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:ReconcileUnrealizedCasts",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:protobuf",
"@local_tsl//tsl/platform:statusor",
"@stablehlo//:stablehlo_ops",
"@stablehlo//:vhlo_ops",
],
)

@@ -1505,6 +1515,8 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/meta:type_traits",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
6 changes: 5 additions & 1 deletion tensorflow/compiler/mlir/lite/debug/BUILD
@@ -22,7 +22,10 @@ cc_library(
visibility = ["//tensorflow/compiler/mlir/lite:__subpackages__"],
deps = [
":debug_options_proto_cc",
"//tensorflow/compiler/mlir/lite/metrics:error_collector_inst",
"//tensorflow/core:portable_gif_internal",
"//tensorflow/lite:model_builder",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
@@ -32,10 +35,11 @@
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
"@local_tsl//tsl/lib/io:buffered_file",
"@local_tsl//tsl/platform:env",
"@local_tsl//tsl/platform:path",
"@local_tsl//tsl/platform:status",
"@local_tsl//tsl/platform:stringpiece",
],
)
87 changes: 86 additions & 1 deletion tensorflow/compiler/mlir/lite/debug/debug.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <stdint.h>

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
@@ -32,24 +33,37 @@ limitations under the License.
#include "absl/time/time.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassInstrumentation.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/FileUtilities.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/ViewOpGraph.h" // from @llvm-project
#include "re2/re2.h" // IWYU pragma: keep
#include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h"
#include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tsl/lib/io/buffered_file.h"
#include "tsl/platform/env.h"
#include "tsl/platform/file_system.h"
#include "tsl/platform/path.h"
#include "tsl/platform/status.h"
#include "tsl/platform/stringpiece.h"

// IWYU pragma: no_include "util/regexp/re2/re2.h"

using mlir::func::FuncOp;

namespace tensorflow {
namespace {

@@ -279,6 +293,73 @@ std::function<bool(mlir::Pass*, mlir::Operation*)> CreatePrintIRFun(

} // namespace

absl::Status PrintFunctionResultMapping(const std::string& result,
mlir::ModuleOp module) {
// Build model from the resultant string to extract the return values from
// their source of truth.
auto model =
tflite::FlatBufferModel::BuildFromBuffer(result.data(), result.size());
if (!model) return absl::NotFoundError("Failed to build model from result");

// Get an unknown location for where we don't have a terminator to get the
// location of the return value from.
auto unknown_loc = mlir::UnknownLoc::get(module.getContext());

auto print_buffer = [&](const tflite::SubGraph& subgraph, int id, int buffer,
std::function<mlir::Location(int)> loc) {
const auto& output_tensor = (*subgraph.tensors())[buffer];
std::cout << "\tname: '"
<< (output_tensor->name() ? output_tensor->name()->str()
: "<<unnamed>>")
<< "' buffer: " << buffer;
if (loc) std::cout << llvm::formatv(" {0}", loc(id)).str();
std::cout << '\n';
};

// For every subgraph print out the name (if available), each result's output
// buffer number and location of the return value (if available).
for (auto* subgraph : *(*model)->subgraphs()) {
std::string subgraph_name =
subgraph->name() ? subgraph->name()->str() : "<<unnamed subgraph>>";

std::cout << '\'' << subgraph_name << "' inputs:\n";
int i = 0;
for (auto input : *subgraph->inputs())
print_buffer(*subgraph, i++, input, nullptr);

std::cout << '\'' << subgraph_name << "' outputs:\n";
mlir::Operation* terminator = nullptr;
if (subgraph->name()) {
if (auto fn = module.lookupSymbol<FuncOp>(subgraph->name()->str()))
terminator = fn.back().getTerminator();
}
i = 0;
for (auto output : *subgraph->outputs()) {
print_buffer(*subgraph, i, output, [&](int i) {
return terminator ? terminator->getOperand(i).getLoc() : unknown_loc;
});
}
}
return absl::OkStatus();
}

absl::Status DumpOpGraphToFile(mlir::ModuleOp module,
const std::string& filename) {
std::string error_message;
auto output = mlir::openOutputFile(filename, &error_message);
if (!error_message.empty()) {
return absl::InternalError(
absl::StrCat("Failed to open file in ", filename));
}
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createPrintOpGraphPass(output->os()));
if (failed(pm.run(module))) {
return absl::InternalError("Failed to dump Op Graph from MLIR module.");
}
output->keep();
return absl::OkStatus();
}

void InitPassManager(mlir::PassManager& pm,
const converter::DebugOptions& options,
llvm::raw_ostream& out) {
@@ -342,6 +423,10 @@ void InitPassManager(
if (options.enable_timing()) {
pm.enableTiming();
}

pm.addInstrumentation(
std::make_unique<mlir::TFL::ErrorCollectorInstrumentation>(
pm.getContext()));
}

} // namespace tensorflow
12 changes: 12 additions & 0 deletions tensorflow/compiler/mlir/lite/debug/debug.h
@@ -16,7 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_DEBUG_DEBUG_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_DEBUG_DEBUG_H_

#include <string>

#include "absl/status/status.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h"

@@ -29,6 +33,14 @@ void InitPassManager(mlir::PassManager& pm,
const converter::DebugOptions& options,
llvm::raw_ostream& out = llvm::outs());

// Print function mapping in the flatbuffer.
absl::Status PrintFunctionResultMapping(const std::string& result,
mlir::ModuleOp module);

// Dumps the op graph of the `module` to `filename` in DOT format.
absl::Status DumpOpGraphToFile(mlir::ModuleOp module,
const std::string& filename);

} // namespace tensorflow

#endif // TENSORFLOW_COMPILER_MLIR_LITE_DEBUG_DEBUG_H_
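For context, a hedged usage sketch of the two entry points declared above; the function name, variable names, and dump path are illustrative, not part of this commit. TF_RETURN_IF_ERROR comes from tsl/platform/errors.h.

// Illustrative only: dump debug artifacts after conversion, assuming
// `module` holds the converted IR and `serialized` the flatbuffer bytes.
absl::Status DumpDebugArtifacts(mlir::ModuleOp module,
                                const std::string& serialized) {
  // Writes the op graph of `module` in GraphViz DOT format.
  TF_RETURN_IF_ERROR(
      tensorflow::DumpOpGraphToFile(module, "/tmp/converted.dot"));
  // Prints each subgraph's input/output buffers and, where a terminator
  // is available, the source location of each return value.
  return tensorflow::PrintFunctionResultMapping(serialized, module);
}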
tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
@@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h"

#include <memory>
#include <optional>
#include <string>
#include <utility>
@@ -44,7 +45,7 @@ absl::Status ConvertGraphDefToTFLiteFlatBuffer(
const GraphDebugInfo& debug_info, const GraphDef& input,
std::string* result) {
using ::tflite::optimize::ReducedPrecisionSupport;
mlir::MLIRContext context;
auto context = std::make_unique<mlir::MLIRContext>();
GraphImportConfig specs;
mlir::quant::QuantizationSpecs quant_specs;

@@ -85,8 +86,8 @@ absl::Status ConvertGraphDefToTFLiteFlatBuffer(
// Register all custom ops, including user-specified custom ops.
TF_RETURN_IF_ERROR(internal::RegisterAllCustomOps(toco_flags));

TF_ASSIGN_OR_RETURN(
auto module, ConvertGraphdefToMlir(input, debug_info, specs, &context));
TF_ASSIGN_OR_RETURN(auto module, ConvertGraphdefToMlir(input, debug_info,
specs, context.get()));

mlir::TFL::PassConfig pass_config(quant_specs);
bool emit_builtin_tflite_ops = !toco_flags.force_select_tf_ops();
@@ -112,7 +113,8 @@ absl::Status ConvertGraphDefToTFLiteFlatBuffer(
// StableHLO Quantizer is not supported for GraphDef inputs, so
// quantization_py_function_lib is set to nullptr.
return internal::ConvertMLIRToTFLiteFlatBuffer(
model_flags, toco_flags, std::move(module), pass_config,
model_flags, toco_flags, std::move(context), std::move(module),
pass_config,
/*saved_model_tags=*/{}, result, /*saved_model_bundle=*/nullptr,
/*quantization_py_function_lib=*/nullptr);
}
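The same ownership handoff recurs in the JAX and SavedModel paths below. In outline (simplified and illustrative; ImportModel stands in for the path-specific import step):

// The context is now heap-allocated so that its ownership can be moved
// into the conversion along with the module; the converter may then
// destroy both as soon as the flatbuffer is serialized, releasing the
// tensors interned in the context.
auto context = std::make_unique<mlir::MLIRContext>();
mlir::OwningOpRef<mlir::ModuleOp> module =
    ImportModel(input, context.get());  // illustrative import step
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
    model_flags, toco_flags, std::move(context), std::move(module),
    pass_config, /*saved_model_tags=*/{}, result,
    /*saved_model_bundle=*/nullptr,
    /*quantization_py_function_lib=*/nullptr);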
9 changes: 5 additions & 4 deletions tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc
@@ -123,7 +123,7 @@ absl::Status ConvertJaxToTFLiteFlatBuffer(const std::string& input,
const toco::ModelFlags& model_flags,
toco::TocoFlags& toco_flags,
std::string* result) {
mlir::MLIRContext context;
auto context = std::make_unique<mlir::MLIRContext>();
mlir::quant::QuantizationSpecs quant_specs;

// Parse input arrays.
@@ -162,9 +162,9 @@ absl::Status ConvertJaxToTFLiteFlatBuffer(const std::string& input,

mlir::OwningOpRef<mlir::ModuleOp> module;
if (model_flags.hlo_file_type() == toco::ModelFlags::HLO_TEXT) {
module = HloTextToMlirHloTranslateFunction(input, &context, false);
module = HloTextToMlirHloTranslateFunction(input, context.get(), false);
} else if (model_flags.hlo_file_type() == toco::ModelFlags::HLO_PROTO) {
module = HloToMlirHloTranslateFunction(input, &context, false);
module = HloToMlirHloTranslateFunction(input, context.get(), false);
} else {
return errors::InvalidArgument("unknown hlo format type.");
}
@@ -191,7 +191,8 @@ absl::Status ConvertJaxToTFLiteFlatBuffer(const std::string& input,
// StableHLO Quantizer is not supported for JAX input models, so
// quantization_py_function_lib is set to nullptr.
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
model_flags, toco_flags, std::move(module), pass_config,
model_flags, toco_flags, std::move(context), std::move(module),
pass_config,
/*saved_model_tags=*/{}, result, /*saved_model_bundle=*/nullptr,
/*quantization_py_function_lib=*/nullptr);
return status;
tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
@@ -137,7 +137,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
const toco::ModelFlags& model_flags, toco::TocoFlags& toco_flags,
std::string* result,
const PyFunctionLibrary* quantization_py_function_lib) {
mlir::MLIRContext context;
auto context = std::make_unique<mlir::MLIRContext>();
mlir::quant::QuantizationSpecs quant_specs;

// Parse input arrays.
@@ -177,10 +177,11 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
auto bundle = std::make_unique<tensorflow::SavedModelBundle>();
TF_ASSIGN_OR_RETURN(
auto module,
ImportSavedModel(
model_flags.saved_model_dir(), model_flags.saved_model_version(),
tags, absl::MakeSpan(custom_opdefs), exported_names, specs,
!toco_flags.enable_tflite_resource_variables(), &context, &bundle));
ImportSavedModel(model_flags.saved_model_dir(),
model_flags.saved_model_version(), tags,
absl::MakeSpan(custom_opdefs), exported_names, specs,
!toco_flags.enable_tflite_resource_variables(),
context.get(), &bundle));

if (!model_flags.input_arrays().empty() ||
!model_flags.output_arrays().empty()) {
@@ -240,8 +241,10 @@ Status ConvertSavedModelToTFLiteFlatBuffer(

// TODO(b/153507667): Pass the session object when importing logic is removed.
auto status = internal::ConvertMLIRToTFLiteFlatBuffer(
model_flags, toco_flags, std::move(module), pass_config, tags, result,
std::move(bundle), quantization_py_function_lib);
model_flags, toco_flags, std::move(context), std::move(module),
pass_config, tags, result, std::move(bundle),
quantization_py_function_lib);

return status;
}

38 changes: 5 additions & 33 deletions tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc
@@ -328,37 +328,14 @@ absl::Status PopulateQuantizationSpecs(
return absl::OkStatus();
}

// Dumps the op graph of the `module` to `filename` in DOT format.
absl::Status DumpOpGraphToFile(mlir::ModuleOp module,
const std::string& filename) {
std::string error_message;
auto output = mlir::openOutputFile(filename, &error_message);
if (!error_message.empty()) {
return errors::InvalidArgument("Failed to open file in ", filename);
}
mlir::PassManager pm(module.getContext());
pm.addPass(mlir::createPrintOpGraphPass(output->os()));
if (failed(pm.run(module))) {
return errors::Unknown("Failed to dump Op Graph from MLIR module.");
}
output->keep();
return absl::OkStatus();
}

absl::Status ConvertMLIRToTFLiteFlatBuffer(
const toco::ModelFlags& model_flags, toco::TocoFlags& toco_flags,
std::unique_ptr<mlir::MLIRContext> context,
mlir::OwningOpRef<mlir::ModuleOp> module,
const mlir::TFL::PassConfig& pass_config,
const std::unordered_set<std::string>& saved_model_tags,
std::string* result, std::unique_ptr<SavedModelBundle> saved_model_bundle,
const PyFunctionLibrary* quantization_py_function_lib) {
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
module.get(),
// rename once we enable the new converter feature flag.
absl::StrCat(toco_flags.dump_graphviz_dir(), "/toco_AT_IMPORT.dot")));
}

mlir::TFL::PassConfig pass_config_copy = pass_config;
pass_config_copy.outline_tf_while = true;

@@ -374,16 +351,11 @@ absl::Status ConvertMLIRToTFLiteFlatBuffer(
});

auto status = ConvertTFExecutorToTFLOrFlatbuffer(
module.get(), /*export_to_mlir=*/false, toco_flags, pass_config_copy,
std::move(context), std::move(module), toco_flags, pass_config_copy,
saved_model_tags, model_flags.saved_model_dir(),
std::move(saved_model_bundle), result, /*serialize_stablehlo_ops=*/false,
quantization_py_function_lib);
if (toco_flags.has_dump_graphviz_dir()) {
TF_RETURN_IF_ERROR(DumpOpGraphToFile(
// rename once we enable the new converter feature flag.
module.get(), absl::StrCat(toco_flags.dump_graphviz_dir(),
"/toco_AFTER_TRANSFORMATIONS.dot")));
}
std::move(saved_model_bundle), result,
/*serialize_stablehlo_ops=*/false, /*export_to_mlir=*/false,
/*print_function_result_mapping=*/false, quantization_py_function_lib);

return status;
}