[go: nahoru, domu]

Skip to content

Commit

Permalink
Add additional debugging options to the TFLite Converter.
Browse files Browse the repository at this point in the history
Adds four new debug-related options that enable printing MLIR to stdout:
* mlir_print_ir_before: Prints MLIR before each pass. Accepts a regex used for
  selectively matching against pass names.
* mlir_print_ir_after:  Prints MLIR after each pass. Accepts a regex used for
  selectively matching against pass names.
* mlir_print_ir_module_scope: Determines whether to print the top-level
  operation when printing IR.
* mlir_elide_elementsattrs_if_larger: Elides ElementsAttrs with number of
  elements larger than the provided integer.

The functionality of the existing options is unchanged. Dumping MLIR to files
and printing MLIR to stdout are controlled independently, so both can be
enabled at the same time if the user so desires.

A future CL will plumb the new options through the Python API.

PiperOrigin-RevId: 545875450
  • Loading branch information
arfaian authored and tensorflower-gardener committed Jul 6, 2023
1 parent 8b3adcb commit 1297d6e
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 10 deletions.
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/debug/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ tf_cc_test(
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
Expand Down
58 changes: 52 additions & 6 deletions tensorflow/compiler/mlir/lite/debug/debug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassInstrumentation.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "re2/re2.h"
#include "re2/re2.h" // IWYU pragma: keep
#include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/tsl/lib/io/buffered_file.h"
Expand All @@ -48,6 +48,7 @@ limitations under the License.
#include "tensorflow/tsl/platform/path.h"
#include "tensorflow/tsl/platform/status.h"
#include "tensorflow/tsl/platform/stringpiece.h"
// IWYU pragma: no_include "util/regexp/re2/re2.h"

namespace tensorflow {
namespace {
Expand Down Expand Up @@ -261,12 +262,39 @@ class DumpInstrumentation : public mlir::PassInstrumentation {
bool printed_ = false;
};

std::function<bool(mlir::Pass*, mlir::Operation*)> CreatePrintIRFun(
const std::string& pass_regex) {
std::function<bool(mlir::Pass*, mlir::Operation*)> fun;
if (pass_regex.empty()) {
return fun;
}
return [pr = pass_regex](mlir::Pass* p, mlir::Operation*) {
static const RE2* const re = new RE2(pr);
if (RE2::FullMatch(p->getName(), *re)) {
return true;
}
return false;
};
}

} // namespace

void InitPassManager(mlir::PassManager& pm,
const converter::DebugOptions& options) {
const converter::DebugOptions& options,
llvm::raw_ostream& out) {
std::string dump_dir = options.mlir_dump_dir();
if (!dump_dir.empty()) {

bool dump_to_dir = !dump_dir.empty();
bool print_to_stdout = !options.mlir_print_ir_before().empty() ||
!options.mlir_print_ir_after().empty();

if (dump_to_dir || print_to_stdout) {
// Necessary for maintaining sequence of passes when dumping MLIR to files
// or stdout.
pm.getContext()->disableMultithreading();
}

if (dump_to_dir) {
dump_dir = tsl::io::JoinPath(
dump_dir, absl::FormatTime("%E4Y%m%d_%H%M%E6S", absl::Now(),
absl::LocalTimeZone()));
Expand All @@ -287,10 +315,28 @@ void InitPassManager(mlir::PassManager& pm,
pm.addInstrumentation(std::make_unique<DumpInstrumentation>(
dump_dir, options.mlir_dump_pass_regex(),
options.mlir_dump_func_regex()));
}

// Necessary for maintaining MLIR dump file name sequence to represent
// order.
pm.getContext()->disableMultithreading();
if (print_to_stdout) {
std::function<bool(mlir::Pass*, mlir::Operation*)>
should_print_ir_before_pass(
CreatePrintIRFun(options.mlir_print_ir_before()));
std::function<bool(mlir::Pass*, mlir::Operation*)>
should_print_ir_after_pass(
CreatePrintIRFun(options.mlir_print_ir_after()));

mlir::OpPrintingFlags opPrintingFlags = mlir::OpPrintingFlags();

if (options.has_mlir_elide_elementsattrs_if_larger()) {
opPrintingFlags.elideLargeElementsAttrs(
options.mlir_elide_elementsattrs_if_larger());
}

pm.enableIRPrinting(should_print_ir_before_pass, should_print_ir_after_pass,
options.mlir_print_ir_module_scope(),
/*printAfterOnlyOnChange=*/true,
/*printAfterOnlyOnFailure=*/false, out,
opPrintingFlags);
}

// Enable pass timing. Note: MLIR expects `mlir::PassManager::enableTiming` to
Expand Down
6 changes: 5 additions & 1 deletion tensorflow/compiler/mlir/lite/debug/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_DEBUG_DEBUG_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_DEBUG_DEBUG_H_

#include "llvm/Support/raw_ostream.h"
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h"

namespace tensorflow {

// Initializes the pass manager with default options that make debugging easier.
// The `out` parameter is exposed for testing purposes only and is not
// intended to be specified by client code.
void InitPassManager(mlir::PassManager& pm,
const converter::DebugOptions& options);
const converter::DebugOptions& options,
llvm::raw_ostream& out = llvm::outs());

} // namespace tensorflow

Expand Down
19 changes: 18 additions & 1 deletion tensorflow/compiler/mlir/lite/debug/debug_options.proto
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package tensorflow.converter;

// Additional parameters that control the debug behavior of the Converter.
//
// Next ID: 5
// Next ID: 9
message DebugOptions {
// If not empty, dumps MLIR to the specified directory. The initial state of
// the MLIR after import will be dumped at the beginning of each pass manager
Expand All @@ -41,4 +41,21 @@ message DebugOptions {

// If true, report the execution time of each MLIR pass.
optional bool mlir_enable_timing = 4 [default = false];

// Prints MLIR before specified passes. Supports regular expressions for
// matching against the names of the desired passes.
optional string mlir_print_ir_before = 5 [default = ""];

// Prints MLIR after specified passes. Supports regular expressions for
// matching against the names of the desired passes. Currently only prints
// after a pass if the MLIR is mutated.
optional string mlir_print_ir_after = 6 [default = ""];

// If true, always print the top-level operation when printing IR for
// print_ir_[before|after].
optional bool mlir_print_ir_module_scope = 7 [default = true];

// Elide ElementsAttrs with "..." that have more elements than the given
// upper limit.
optional int64 mlir_elide_elementsattrs_if_larger = 8;
}
118 changes: 116 additions & 2 deletions tensorflow/compiler/mlir/lite/debug/debug_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ limitations under the License.

#include "tensorflow/compiler/mlir/lite/debug/debug.h"

#include <stdint.h>

#include <cstdlib>
#include <memory>
#include <string>
Expand All @@ -24,11 +26,15 @@ limitations under the License.
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinDialect.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OpDefinition.h" // from @llvm-project
Expand All @@ -49,6 +55,7 @@ limitations under the License.
namespace tensorflow {
namespace {

using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::Not;

Expand All @@ -59,6 +66,16 @@ class NopPass : public mlir::PassWrapper<NopPass, mlir::OperationPass<>> {
void runOnOperation() override {}
};

// Test pass that modifies the operation it runs on by attaching a unit
// attribute named "tfl.random_attr" to it.
class MutatePass : public mlir::PassWrapper<MutatePass, mlir::OperationPass<>> {
 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MutatePass)

  void runOnOperation() override {
    // A builder is only needed here as a convenient source of the UnitAttr.
    mlir::OpBuilder attr_builder(&getContext());
    const auto marker = attr_builder.getUnitAttr();
    getOperation()->setAttr("tfl.random_attr", marker);
  }
};

class AlwaysFailPass
: public mlir::PassWrapper<AlwaysFailPass, mlir::OperationPass<>> {
public:
Expand All @@ -74,6 +91,7 @@ class InitPassManagerTest : public testing::Test {
mlir::registerPassManagerCLOptions();
mlir::DialectRegistry registry;
registry.insert<mlir::BuiltinDialect>();
registry.insert<mlir::arith::ArithDialect>();
registry.insert<mlir::func::FuncDialect>();
registry.insert<mlir::TFL::TensorFlowLiteDialect>();
return registry;
Expand All @@ -87,8 +105,13 @@ class InitPassManagerTest : public testing::Test {
auto func = builder.create<mlir::func::FuncOp>( //
builder.getUnknownLoc(), "main", builder.getFunctionType({}, {}));
func->setAttr("tfl.func", builder.getUnitAttr());

builder.setInsertionPointToStart(func.addEntryBlock());
llvm::SmallVector<int> shape{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
builder.create<mlir::arith::ConstantOp>(
builder.getUnknownLoc(),
mlir::DenseIntElementsAttr::get(
mlir::RankedTensorType::get(shape.size(), builder.getI32Type()),
shape));
builder.create<mlir::func::ReturnOp>(builder.getUnknownLoc());
}

Expand Down Expand Up @@ -139,7 +162,7 @@ TEST_F(InitPassManagerTest, CrashReproducer) {
EXPECT_THAT(mlir_dump, Not(IsEmpty()));
}

TEST_F(InitPassManagerTest, Dump) {
TEST_F(InitPassManagerTest, DumpToDir) {
converter::DebugOptions debug_options;
*debug_options.mutable_mlir_dump_dir() = path_;
*debug_options.mutable_mlir_dump_pass_regex() = R"(.*NopPass)";
Expand Down Expand Up @@ -173,5 +196,96 @@ TEST_F(InitPassManagerTest, Dump) {
}
}

// With only `mlir_print_ir_before` set to match every pass, the captured
// output must contain the "Before" banner for NopPass and no "After" banner.
TEST_F(InitPassManagerTest, PrintIRBeforeEverything) {
  std::string printed;
  llvm::raw_string_ostream printed_stream(printed);

  converter::DebugOptions opts;
  opts.set_mlir_print_ir_before(R"(.*)");

  mlir::PassManager pass_manager(&context_);
  InitPassManager(pass_manager, opts, printed_stream);
  pass_manager.addPass(std::make_unique<NopPass>());
  ASSERT_TRUE(mlir::succeeded(pass_manager.run(*module_)));

  const std::string before_banner =
      "IR Dump Before tensorflow::(anonymous namespace)::NopPass";
  const std::string after_banner =
      "IR Dump After tensorflow::(anonymous namespace)::NopPass";
  EXPECT_THAT(printed, HasSubstr(before_banner));
  EXPECT_THAT(printed, Not(HasSubstr(after_banner)));
}

// With only `mlir_print_ir_after` set to match every pass, running the
// IR-modifying MutatePass must produce an "After" banner and no "Before"
// banner in the captured output.
TEST_F(InitPassManagerTest, PrintIRAfterEverything) {
  std::string printed;
  llvm::raw_string_ostream printed_stream(printed);

  converter::DebugOptions opts;
  opts.set_mlir_print_ir_after(R"(.*)");

  mlir::PassManager pass_manager(&context_);
  InitPassManager(pass_manager, opts, printed_stream);
  pass_manager.addPass(std::make_unique<MutatePass>());
  ASSERT_TRUE(mlir::succeeded(pass_manager.run(*module_)));

  const std::string after_banner =
      "IR Dump After tensorflow::(anonymous namespace)::MutatePass";
  const std::string before_banner =
      "IR Dump Before tensorflow::(anonymous namespace)::MutatePass";
  EXPECT_THAT(printed, HasSubstr(after_banner));
  EXPECT_THAT(printed, Not(HasSubstr(before_banner)));
}

// With both the before and after options matching every pass, the captured
// output must contain both the "After" and the "Before" banner for MutatePass.
TEST_F(InitPassManagerTest, PrintIRBeforeAndAfterEverything) {
  std::string printed;
  llvm::raw_string_ostream printed_stream(printed);

  converter::DebugOptions opts;
  opts.set_mlir_print_ir_before(R"(.*)");
  opts.set_mlir_print_ir_after(R"(.*)");

  mlir::PassManager pass_manager(&context_);
  InitPassManager(pass_manager, opts, printed_stream);
  pass_manager.addPass(std::make_unique<MutatePass>());
  ASSERT_TRUE(mlir::succeeded(pass_manager.run(*module_)));

  const std::string after_banner =
      "IR Dump After tensorflow::(anonymous namespace)::MutatePass";
  const std::string before_banner =
      "IR Dump Before tensorflow::(anonymous namespace)::MutatePass";
  EXPECT_THAT(printed, HasSubstr(after_banner));
  EXPECT_THAT(printed, HasSubstr(before_banner));
}

// With the elision limit set to 5, the printed IR is expected to contain the
// elided placeholder "dense_resource<__elided__>".
TEST_F(InitPassManagerTest, ElideLargeElementAttrs) {
  std::string printed;
  llvm::raw_string_ostream printed_stream(printed);

  converter::DebugOptions opts;
  opts.set_mlir_print_ir_before(R"(.*)");
  opts.set_mlir_elide_elementsattrs_if_larger(5);

  mlir::PassManager pass_manager(&context_);
  InitPassManager(pass_manager, opts, printed_stream);
  pass_manager.addPass(std::make_unique<MutatePass>());
  ASSERT_TRUE(mlir::succeeded(pass_manager.run(*module_)));

  EXPECT_THAT(printed, HasSubstr("dense_resource<__elided__>"));
}

// With the elision limit set to 11, the printed IR is expected to contain the
// constant's elements spelled out in full.
TEST_F(InitPassManagerTest, DontElideSmallerElementAttrs) {
  std::string printed;
  llvm::raw_string_ostream printed_stream(printed);

  converter::DebugOptions opts;
  opts.set_mlir_print_ir_before(R"(.*)");
  opts.set_mlir_elide_elementsattrs_if_larger(11);

  mlir::PassManager pass_manager(&context_);
  InitPassManager(pass_manager, opts, printed_stream);
  pass_manager.addPass(std::make_unique<MutatePass>());
  ASSERT_TRUE(mlir::succeeded(pass_manager.run(*module_)));

  const std::string full_values = "dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]>";
  EXPECT_THAT(printed, HasSubstr(full_values));
}

} // namespace
} // namespace tensorflow

0 comments on commit 1297d6e

Please sign in to comment.