Lift HashTable ops as function arguments
In graph execution mode, resource ops with the same `shared_name` attribute point to the same underlying resource. This is not true in eager execution mode. This CL lifts HashTable ops as function arguments, updates their callers, and unifies the lifted ops with any existing ones in the caller function.

PiperOrigin-RevId: 535495431
thaink authored and tensorflower-gardener committed May 26, 2023
1 parent 4b6ea62 commit 41787b2
Showing 8 changed files with 581 additions and 12 deletions.
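In essence, the lifting rewrite does the following (a schematic sketch with illustrative names, not the committed code; the real pass below additionally tracks `shared_name`s, rewrites function signatures, and recurses into callers):

#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

// Schematic only: inside the callee, replace the HashTable op's handle with a
// fresh block argument of the same type; each caller then clones the op and
// passes the handle in at the matching operand position.
mlir::BlockArgument LiftHandleToArgument(mlir::Operation* hash_table,
                                         mlir::Block& callee_body) {
  mlir::Value handle = hash_table->getResult(0);
  mlir::BlockArgument arg =
      callee_body.addArgument(handle.getType(), hash_table->getLoc());
  handle.replaceAllUsesWith(arg);
  return arg;
}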
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/quantization/tensorflow/BUILD
@@ -392,11 +392,13 @@ cc_library(
"passes/insert_restore_op.cc",
"passes/insert_save_op.cc",
"passes/issue_ids_of_custom_aggregation_ops.cc",
"passes/lift_hashtable_ops_as_args.cc",
"passes/lift_quantizable_spots_as_functions.cc",
"passes/lift_quantizable_spots_as_functions.inc",
"passes/lift_quantizable_spots_as_functions_drq.cc",
"passes/lift_quantizable_spots_as_functions_drq.inc",
"passes/mark_functions_noinline.cc",
"passes/merge_duplicate_resource_ops.cc",
"passes/merge_initializer_function_ops_to_main.cc",
"passes/merge_save_function_ops_to_main.cc",
"passes/optimize.cc",
210 changes: 210 additions & 0 deletions tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_hashtable_ops_as_args.cc
@@ -0,0 +1,210 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/IRMapping.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/constants.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"

namespace mlir {
namespace quant {
namespace {

constexpr StringRef kSharedNameAttr = "shared_name";

class LiftHashTableOpsAsArgsPass
: public PassWrapper<LiftHashTableOpsAsArgsPass, OperationPass<ModuleOp>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LiftHashTableOpsAsArgsPass)
explicit LiftHashTableOpsAsArgsPass() = default;

StringRef getArgument() const final {
// This is the argument used to refer to the pass in
// the textual format (on the command line, for example).
return "quant-lift-hashtable-ops-as-args";
}
StringRef getDescription() const final {
return "Lifts HashTable ops as function arguments.";
}

void runOnOperation() override;
};

// Checks if the given op is a HashTable op.
bool IsHashTableOp(Operation* op) {
return llvm::isa<TF::HashTableOp, TF::HashTableV2Op,
TF::MutableHashTableV2Op>(op);
}

// Checks if the function is the main or initializer function.
bool IsMainOrInitializerFunction(ModuleOp module, func::FuncOp func) {
if (func.getSymName().equals(tensorflow::kImportModelDefaultGraphFuncName) ||
func.getSymName().equals(kTfQuantSaveFuncName)) {
return true;
}

for (func::FuncOp init_func :
tf_saved_model::GetInitializerFunctions(module)) {
if (func.getSymName().equals(init_func.getSymName())) {
return true;
}
}
return false;
}

// Checks if the function is only used by supported ops. Returns false when the
// function has no uses. Currently, only PartitionedCall is supported.
// TODO(b/284222309): Support lifting for functions called by control flow.
bool UsedBySupportedOps(ModuleOp module, func::FuncOp func) {
auto function_uses =
SymbolTable::getSymbolUses(func, &module.getBodyRegion());
if (!function_uses.has_value()) return false;
for (auto& function_use : function_uses.value()) {
if (!llvm::isa<TF::PartitionedCallOp, TF::StatefulPartitionedCallOp>(
function_use.getUser())) {
return false;
}
}
return true;
}

// Returns the `shared_name` attribute value if it exists. If not, returns an
// empty string.
StringRef GetSharedName(Operation* op) {
if (!op->hasAttrOfType<StringAttr>(kSharedNameAttr)) return "";
return op->getAttrOfType<StringAttr>(kSharedNameAttr).getValue();
}

// Checks if the HashTable is initialized. This function assumes that the
// HashTable is initialized if it appears in an initializer function, since it
// can't check the actual value.
bool IsResourceInitialized(ModuleOp module_op, Operation* hash_table) {
StringRef shared_name = GetSharedName(hash_table);
if (shared_name.empty()) return false;

for (func::FuncOp init_func_op :
tf_saved_model::GetInitializerFunctions(module_op)) {
for (Operation& op : init_func_op.getBody().getOps()) {
StringRef other_shared_name = GetSharedName(&op);
if (IsHashTableOp(&op) && other_shared_name.equals(shared_name)) {
return true;
}
}
}
return false;
}

// Lifts HashTable ops in the target function as function arguments and returns
// the lifted ops. These ops will then be added to the caller function and
// passed to the target function.
LogicalResult LiftHashTableOpsToArguments(ModuleOp module_op,
func::FuncOp target_func) {
if (!llvm::hasSingleElement(target_func)) return success();
if (!UsedBySupportedOps(module_op, target_func)) return success();
if (IsMainOrInitializerFunction(module_op, target_func)) return success();

llvm::StringMap<int> shared_name_to_arg_idx;
llvm::SmallDenseMap<Operation*, int> lifted_op_to_arg_idx;
Block& block = target_func.front();
auto func_type = target_func.getFunctionType();

for (Operation& op : block.without_terminator()) {
StringRef shared_name = GetSharedName(&op);
if (shared_name.empty() || !IsHashTableOp(&op)) continue;
if (!IsResourceInitialized(module_op, &op)) continue;

auto it =
shared_name_to_arg_idx.insert({shared_name, block.getNumArguments()});
if (it.second) {
auto resource_type = op.getResult(0).getType();
op.getResult(0).replaceAllUsesWith(
block.addArgument(resource_type, op.getLoc()));
AddEntryFunctionInput(
absl::StrCat("hash_table_", it.first->getValue(), ":0"), target_func);
// Avoid deleting the op here; it is first cloned into the caller functions
// and erased later.
lifted_op_to_arg_idx.insert({&op, it.first->getValue()});
} else {
op.getResult(0).replaceAllUsesWith(
block.getArgument(it.first->getValue()));
op.erase();
}
}
if (lifted_op_to_arg_idx.empty()) return success();

// Update the function signature as well as its uses.
target_func.setType(FunctionType::get(target_func.getContext(),
block.getArgumentTypes(),
func_type.getResults()));

IRMapping mapping;
OpBuilder builder(module_op);
OpBuilder::InsertionGuard g(builder);
// The function has been checked to have at least one use.
auto function_uses =
SymbolTable::getSymbolUses(target_func, &module_op.getBodyRegion());
for (auto& function_use : function_uses.value()) {
auto call_op = function_use.getUser();
auto caller_func = call_op->getParentOfType<func::FuncOp>();
if (!caller_func) return failure();

builder.setInsertionPoint(call_op);
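// Clone each lifted op in front of the call and pass its handle at the
// operand index that matches the new block argument.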
for (auto [lifted_op, arg_idx] : lifted_op_to_arg_idx) {
auto new_op = builder.clone(*lifted_op, mapping);
call_op->insertOperands(arg_idx, new_op->getResult(0));
}

// Try to lift recursively until reaching the main function.
if (failed(LiftHashTableOpsToArguments(module_op, caller_func))) {
return failure();
}
}

// Erase the lifted operations explicitly.
for (auto [lifted_op, arg_idx] : lifted_op_to_arg_idx) {
lifted_op->erase();
}

return success();
}

void LiftHashTableOpsAsArgsPass::runOnOperation() {
auto module_op = getOperation();

for (auto func_op : module_op.getOps<func::FuncOp>()) {
if (failed(LiftHashTableOpsToArguments(module_op, func_op))) {
signalPassFailure();
return;
}
}
}

static PassRegistration<LiftHashTableOpsAsArgsPass> pass;

} // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateLiftHashTableOpsAsArgsPass() {
return std::make_unique<LiftHashTableOpsAsArgsPass>();
}

} // namespace quant
} // namespace mlir
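For reference, a minimal sketch of running the new pass programmatically (an assumed harness, not part of this commit; the `PassRegistration` above also makes it available to tf-opt-style drivers as `-quant-lift-hashtable-ops-as-args`):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"

// Runs the lifting pass over an already-imported module.
mlir::LogicalResult LiftHashTables(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::quant::CreateLiftHashTableOpsAsArgsPass());
  return pm.run(module);
}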
139 changes: 139 additions & 0 deletions tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_duplicate_resource_ops.cc
@@ -0,0 +1,139 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <string>

#include "llvm/ADT/StringRef.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace quant {
namespace {

using ::mlir::tf_executor::GraphOp;
using ::mlir::tf_executor::IslandOp;

constexpr StringRef kSharedNameAttr = "shared_name";

class MergeDuplicateResourceOpsPass
: public PassWrapper<MergeDuplicateResourceOpsPass,
OperationPass<func::FuncOp>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeDuplicateResourceOpsPass)

StringRef getArgument() const final {
return "quant-merge-duplicate-resource-ops";
}

StringRef getDescription() const final {
return "Merge resource ops that have the same shared name.";
}

void runOnOperation() override;
};

// Checks if the island op contains a resource op like Variable or HashTable
// and returns that resource op. Otherwise, returns null.
Operation* GetResourceOp(Operation* op) {
// Check if the island has only one block that contains exactly two ops:
// one resource op and one Yield op.
auto island_op = llvm::dyn_cast_or_null<IslandOp>(op);
if (!island_op || !island_op.getBody().hasOneBlock()) return nullptr;
auto& island_block = island_op.getBody().front();
if (++island_block.begin() != --island_block.end()) return nullptr;

Operation* resource_op = &island_block.front();
if (llvm::isa<TF::VarHandleOp, TF::HashTableOp, TF::HashTableV2Op,
TF::MutableHashTableV2Op>(resource_op)) {
return resource_op;
}
return nullptr;
}

// Returns the `shared_name` attribute value if it exists. If not, returns an
// empty string.
StringRef GetSharedName(Operation* op) {
if (!op->hasAttrOfType<StringAttr>(kSharedNameAttr)) return "";
return op->getAttrOfType<StringAttr>(kSharedNameAttr).getValue();
}

// Gets the GraphOp from the function op. Returns an empty op iff it doesn't
// exist.
// TODO(b/284222084): Move executor dialect utilities to a new library.
GraphOp GetGraphOpFromFuncOp(func::FuncOp func_op) {
if (func_op->getNumRegions() == 0 || func_op.getBody().empty()) return {};

auto graph_op_range = func_op.front().without_terminator();
if (llvm::hasSingleElement(graph_op_range)) {
// The pass runs on a valid tf_executor dialect, so the op should be the
// GraphOp.
return cast<GraphOp>(graph_op_range.begin());
}

return {};
}

void MergeDuplicateResourceOpsPass::runOnOperation() {
func::FuncOp func_op = getOperation();
GraphOp graph_op = GetGraphOpFromFuncOp(func_op);
if (!graph_op) return;

llvm::StringMap<Operation*> shared_name_to_resource;
llvm::SmallVector<Operation*> ops_to_remove;
for (Operation& op : graph_op.GetBody().without_terminator()) {
Operation* resource_op = GetResourceOp(&op);
if (!resource_op) continue;
StringRef shared_name = GetSharedName(resource_op);
if (shared_name.empty()) continue;

if (!shared_name_to_resource.contains(shared_name)) {
shared_name_to_resource[shared_name] = resource_op;
continue;
}

auto existing_resource = shared_name_to_resource[shared_name];
if (resource_op->getName().getStringRef() !=
existing_resource->getName().getStringRef() ||
resource_op->getResult(0).getType() !=
existing_resource->getResult(0).getType()) {
resource_op->emitOpError(
"This op has the same `shared_name` but different type with another "
"resource op in the function");
signalPassFailure();
return;
}
op.replaceAllUsesWith(existing_resource->getParentOp()->getResults());
ops_to_remove.push_back(&op);
}

// Remove ops after the loop to avoid invalidating the iterator.
for (Operation* op : ops_to_remove) {
op->erase();
}
}

static PassRegistration<MergeDuplicateResourceOpsPass> pass{};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
CreateMergeDuplicateResourceOpsPass() {
return std::make_unique<MergeDuplicateResourceOpsPass>();
}

} // namespace quant
} // namespace mlir
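Distilled from the loop above, the merge criterion amounts to the following predicate (an illustrative restatement, as if appended to this file so it can reuse the `GetSharedName` helper; it is not code from this commit):

// Two resource ops are mergeable only when the op kind, the result type, and
// a non-empty `shared_name` all agree.
bool IsDuplicateResource(mlir::Operation* a, mlir::Operation* b) {
  return a->getName() == b->getName() &&
         a->getResult(0).getType() == b->getResult(0).getType() &&
         !GetSharedName(a).empty() &&
         GetSharedName(a) == GetSharedName(b);
}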
10 changes: 10 additions & 0 deletions tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h
@@ -213,6 +213,16 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateConvertTpuModelToCpuPass();
// model quantization.
std::unique_ptr<OperationPass<ModuleOp>> CreateCastBf16OpsToF32Pass();

// Creates a pass that lifts HashTable ops as function arguments. In the graph
// execution mode, resource ops with the same `shared_name` attribute point to
// the same underlying resource. This is not true in the eager execution mode.
// Lifting resource ops as arguments will help unify them across functions.
std::unique_ptr<OperationPass<ModuleOp>> CreateLiftHashTableOpsAsArgsPass();

// Creates a pass that merges duplicate resource ops in each function. Two
// resource ops are considered duplicated if they have the same `shared_name`.
std::unique_ptr<OperationPass<func::FuncOp>>
CreateMergeDuplicateResourceOpsPass();
} // namespace quant
} // namespace mlir
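A sketch of one plausible way to combine the two passes (the ordering and pipeline position are assumptions; the actual quantization pipeline wiring is not part of this diff):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h"

void AddResourceUnificationPasses(mlir::PassManager& pm) {
  // Hoist HashTable ops out of callees so duplicates surface in the caller.
  pm.addPass(mlir::quant::CreateLiftHashTableOpsAsArgsPass());
  // Then collapse resource ops sharing a `shared_name` within each function.
  // Note: this pass inspects the tf_executor graph, so it expects functions
  // in executor dialect form.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::quant::CreateMergeDuplicateResourceOpsPass());
}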

