-
Notifications
You must be signed in to change notification settings - Fork 74k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Lift HashTable ops as function arguments
In the graph execution mode, resource ops with the same `shared_name` attribute point to the same underlying resource. This is not true in the eager execution mode. This CL lifts HashTable ops as arguments, updates their callers, and unifies the lifted ops with any existing ones in the caller function. PiperOrigin-RevId: 535495431
- Loading branch information
1 parent
4b6ea62
commit 41787b2
Showing
8 changed files
with
581 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
210 changes: 210 additions & 0 deletions
210
tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_hashtable_ops_as_args.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/IRMapping.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/constants.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
|
||
namespace mlir { | ||
namespace quant { | ||
namespace { | ||
|
||
constexpr StringRef kSharedNameAttr = "shared_name"; | ||
|
||
class LiftHashTableOpsAsArgsPass | ||
: public PassWrapper<LiftHashTableOpsAsArgsPass, OperationPass<ModuleOp>> { | ||
public: | ||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LiftHashTableOpsAsArgsPass) | ||
explicit LiftHashTableOpsAsArgsPass() = default; | ||
|
||
StringRef getArgument() const final { | ||
// This is the argument used to refer to the pass in | ||
// the textual format (on the commandline for example). | ||
return "quant-lift-hashtable-ops-as-args"; | ||
} | ||
StringRef getDescription() const final { | ||
return "Lifts HashTable ops as function arguments."; | ||
} | ||
|
||
void runOnOperation() override; | ||
}; | ||
|
||
// Returns true when `op` is one of the TF HashTable resource-creating ops.
bool IsHashTableOp(Operation* op) {
  return llvm::isa<TF::HashTableOp, TF::HashTableV2Op,
                   TF::MutableHashTableV2Op>(op);
}
|
||
// Returns true if `func` is the default graph (main) function, the
// quantization save function, or one of the saved-model initializer
// functions of `module`.
bool IsMainOrInitializerFunction(ModuleOp module, func::FuncOp func) {
  const StringRef func_name = func.getSymName();
  if (func_name.equals(tensorflow::kImportModelDefaultGraphFuncName) ||
      func_name.equals(kTfQuantSaveFuncName)) {
    return true;
  }

  for (func::FuncOp init_func :
       tf_saved_model::GetInitializerFunctions(module)) {
    if (func_name.equals(init_func.getSymName())) return true;
  }
  return false;
}
|
||
// Returns true if every use of `func` within `module` is a
// (Stateful)PartitionedCall op. Returns false when the function has no uses.
// TODO(b/284222309): Support lifting for functions called by control flow.
bool UsedBySupportedOps(ModuleOp module, func::FuncOp func) {
  auto uses = SymbolTable::getSymbolUses(func, &module.getBodyRegion());
  if (!uses.has_value()) return false;

  for (const auto& use : *uses) {
    const bool is_supported_call =
        llvm::isa<TF::PartitionedCallOp, TF::StatefulPartitionedCallOp>(
            use.getUser());
    if (!is_supported_call) return false;
  }
  return true;
}
|
||
// Returns the `shared_name` attribute value if it exists. If not, returns an
// empty string.
StringRef GetSharedName(Operation* op) {
  // A single attribute-map lookup; the original hasAttrOfType/getAttrOfType
  // pair searched the attribute dictionary twice.
  if (auto attr = op->getAttrOfType<StringAttr>(kSharedNameAttr)) {
    return attr.getValue();
  }
  return "";
}
|
||
// Returns true when a HashTable op with the same `shared_name` as
// `hash_table` appears in one of the module's initializer functions. The
// actual resource contents cannot be inspected here, so presence in an
// initializer is used as a proxy for "initialized".
bool IsResourceInitialized(ModuleOp module_op, Operation* hash_table) {
  const StringRef shared_name = GetSharedName(hash_table);
  if (shared_name.empty()) return false;

  for (func::FuncOp init_func_op :
       tf_saved_model::GetInitializerFunctions(module_op)) {
    for (Operation& candidate : init_func_op.getBody().getOps()) {
      if (!IsHashTableOp(&candidate)) continue;
      if (GetSharedName(&candidate).equals(shared_name)) return true;
    }
  }
  return false;
}
|
||
// Lifts HashTable ops in the target function as function arguments, clones
// the lifted ops into every caller, and recursively applies the same lifting
// to the callers until the main function is reached.
LogicalResult LiftHashTableOpsToArguments(ModuleOp module_op,
                                          func::FuncOp target_func) {
  // Only lift from single-block functions that are exclusively called via
  // (Stateful)PartitionedCall and are not the main/initializer functions.
  if (!llvm::hasSingleElement(target_func)) return success();
  if (!UsedBySupportedOps(module_op, target_func)) return success();
  if (IsMainOrInitializerFunction(module_op, target_func)) return success();

  llvm::StringMap<int> shared_name_to_arg_idx;
  llvm::SmallDenseMap<Operation*, int> lifted_op_to_arg_idx;
  Block& block = target_func.front();
  auto func_type = target_func.getFunctionType();

  for (Operation& op : block.without_terminator()) {
    StringRef shared_name = GetSharedName(&op);
    if (shared_name.empty() || !IsHashTableOp(&op)) continue;
    // Only lift tables that appear in an initializer; uninitialized tables
    // stay in place.
    if (!IsResourceInitialized(module_op, &op)) continue;

    auto it =
        shared_name_to_arg_idx.insert({shared_name, block.getNumArguments()});
    if (it.second) {
      // First occurrence of this shared_name: replace the op's result with a
      // new block argument of the same resource type.
      auto resource_type = op.getResult(0).getType();
      op.getResult(0).replaceAllUsesWith(
          block.addArgument(resource_type, op.getLoc()));
      AddEntryFunctionInput(
          absl::StrCat("hash_table_", it.first->getValue(), ":0"), target_func);
      // Avoid deleting the op here, clone it to the caller function first.
      lifted_op_to_arg_idx.insert({&op, it.first->getValue()});
    } else {
      // Duplicate shared_name: reuse the argument created for the first one.
      op.getResult(0).replaceAllUsesWith(
          block.getArgument(it.first->getValue()));
      op.erase();
    }
  }
  if (lifted_op_to_arg_idx.empty()) return success();

  // Sort the lifted ops by argument index. SmallDenseMap iteration order is
  // unspecified, and `insertOperands` below must be called with increasing
  // indices: visiting a later index first would be out of bounds for the
  // caller's current operand list and could misorder the new operands
  // relative to the new block arguments.
  std::vector<std::pair<Operation*, int>> lifted_ops(
      lifted_op_to_arg_idx.begin(), lifted_op_to_arg_idx.end());
  std::sort(
      lifted_ops.begin(), lifted_ops.end(),
      [](const auto& lhs, const auto& rhs) { return lhs.second < rhs.second; });

  // Update the function signature as well as its uses.
  target_func.setType(FunctionType::get(target_func.getContext(),
                                        block.getArgumentTypes(),
                                        func_type.getResults()));

  IRMapping mapping;
  OpBuilder builder(module_op);
  OpBuilder::InsertionGuard g(builder);
  // The function has been checked to have at least one use.
  auto function_uses =
      SymbolTable::getSymbolUses(target_func, &module_op.getBodyRegion());
  for (auto& function_use : function_uses.value()) {
    auto call_op = function_use.getUser();
    auto caller_func = call_op->getParentOfType<func::FuncOp>();
    if (!caller_func) return failure();

    // Clone each lifted op next to the call and pass it as the new operand.
    builder.setInsertionPoint(call_op);
    for (auto [lifted_op, arg_idx] : lifted_ops) {
      auto new_op = builder.clone(*lifted_op, mapping);
      call_op->insertOperands(arg_idx, new_op->getResult(0));
    }

    // Try to lift recursively until the main function.
    if (failed(LiftHashTableOpsToArguments(module_op, caller_func))) {
      return failure();
    }
  }

  // Erase the lifted operations explicitly.
  for (auto& [lifted_op, arg_idx] : lifted_ops) {
    lifted_op->erase();
  }

  return success();
}
|
||
void LiftHashTableOpsAsArgsPass::runOnOperation() { | ||
auto module_op = getOperation(); | ||
|
||
for (auto func_op : module_op.getOps<func::FuncOp>()) { | ||
if (failed(LiftHashTableOpsToArguments(module_op, func_op))) { | ||
signalPassFailure(); | ||
return; | ||
} | ||
} | ||
} | ||
|
||
// Registers the pass with the global pass registry.
static PassRegistration<LiftHashTableOpsAsArgsPass> pass{};
|
||
} // namespace | ||
|
||
// Creates an instance of the HashTable-lifting pass.
std::unique_ptr<OperationPass<ModuleOp>> CreateLiftHashTableOpsAsArgsPass() {
  return std::make_unique<LiftHashTableOpsAsArgsPass>();
}
|
||
} // namespace quant | ||
} // namespace mlir |
139 changes: 139 additions & 0 deletions
139
tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_duplicate_resource_ops.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
==============================================================================*/ | ||
#include <memory> | ||
#include <string> | ||
|
||
#include "llvm/ADT/StringRef.h" | ||
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" | ||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" | ||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" | ||
|
||
namespace mlir { | ||
namespace quant { | ||
namespace { | ||
|
||
using ::mlir::tf_executor::GraphOp; | ||
using ::mlir::tf_executor::IslandOp; | ||
|
||
constexpr StringRef kSharedNameAttr = "shared_name"; | ||
|
||
class MergeDuplicateResourceOpsPass | ||
: public PassWrapper<MergeDuplicateResourceOpsPass, | ||
OperationPass<func::FuncOp>> { | ||
public: | ||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeDuplicateResourceOpsPass) | ||
|
||
StringRef getArgument() const final { | ||
return "quant-merge-duplicate-resource-ops"; | ||
} | ||
|
||
StringRef getDescription() const final { | ||
return "Merge resource ops that have the same shared name."; | ||
} | ||
|
||
void runOnOperation() override; | ||
}; | ||
|
||
// If `op` is a tf_executor island whose body wraps exactly one
// resource-creating op (Variable or HashTable) plus its terminator, returns
// that resource op; otherwise returns null.
Operation* GetResourceOp(Operation* op) {
  auto island_op = llvm::dyn_cast_or_null<IslandOp>(op);
  if (!island_op || !island_op.getBody().hasOneBlock()) return nullptr;

  // The island's single block must contain exactly two ops: the candidate
  // resource op followed by the Yield terminator.
  Block& island_block = island_op.getBody().front();
  if (++island_block.begin() != --island_block.end()) return nullptr;

  Operation* resource_op = &island_block.front();
  if (llvm::isa<TF::VarHandleOp, TF::HashTableOp, TF::HashTableV2Op,
                TF::MutableHashTableV2Op>(resource_op)) {
    return resource_op;
  }
  return nullptr;
}
|
||
// Returns the `shared_name` attribute value if it exists. If not, returns an
// empty string.
StringRef GetSharedName(Operation* op) {
  // A single attribute-map lookup; the original hasAttrOfType/getAttrOfType
  // pair searched the attribute dictionary twice.
  if (auto attr = op->getAttrOfType<StringAttr>(kSharedNameAttr)) {
    return attr.getValue();
  }
  return "";
}
|
||
// Returns the tf_executor GraphOp wrapped by `func_op`, or an empty op iff
// the function body does not consist of a single GraphOp.
// TODO(b/284222084): Move executor dialect utilities to a new library.
GraphOp GetGraphOpFromFuncOp(func::FuncOp func_op) {
  if (func_op->getNumRegions() == 0 || func_op.getBody().empty()) return {};

  auto body_ops = func_op.front().without_terminator();
  if (!llvm::hasSingleElement(body_ops)) return {};

  // The pass runs on a valid tf_executor dialect, so the single remaining op
  // should be the GraphOp.
  return cast<GraphOp>(body_ops.begin());
}
|
||
void MergeDuplicateResourceOpsPass::runOnOperation() { | ||
func::FuncOp func_op = getOperation(); | ||
GraphOp graph_op = GetGraphOpFromFuncOp(func_op); | ||
if (!graph_op) return; | ||
|
||
llvm::StringMap<Operation*> shared_name_to_resource; | ||
llvm::SmallVector<Operation*> ops_to_remove; | ||
for (Operation& op : graph_op.GetBody().without_terminator()) { | ||
Operation* resource_op = GetResourceOp(&op); | ||
if (!resource_op) continue; | ||
StringRef shared_name = GetSharedName(resource_op); | ||
if (shared_name.empty()) continue; | ||
|
||
if (!shared_name_to_resource.contains(shared_name)) { | ||
shared_name_to_resource[shared_name] = resource_op; | ||
continue; | ||
} | ||
|
||
auto existing_resource = shared_name_to_resource[shared_name]; | ||
if (resource_op->getName().getStringRef() != | ||
existing_resource->getName().getStringRef() || | ||
resource_op->getResult(0).getType() != | ||
existing_resource->getResult(0).getType()) { | ||
resource_op->emitOpError( | ||
"This op has the same `shared_name` but different type with another " | ||
"resource op in the function"); | ||
signalPassFailure(); | ||
return; | ||
} | ||
op.replaceAllUsesWith(existing_resource->getParentOp()->getResults()); | ||
ops_to_remove.push_back(&op); | ||
} | ||
|
||
// Remove op after the loop to avoid crash. | ||
for (Operation* op : ops_to_remove) { | ||
op->erase(); | ||
} | ||
} | ||
|
||
// Registers the pass with the global pass registry.
static PassRegistration<MergeDuplicateResourceOpsPass> pass;
|
||
} // namespace | ||
|
||
std::unique_ptr<OperationPass<func::FuncOp>> | ||
CreateMergeDuplicateResourceOpsPass() { | ||
return std::make_unique<MergeDuplicateResourceOpsPass>(); | ||
} | ||
|
||
} // namespace quant | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.