[go: nahoru, domu]

Skip to content

Commit

Permalink
Fix bug where a shlo FuncOp without an explicit visibility attribute would segfault.
Browse files Browse the repository at this point in the history

Relaxes the search criteria for the FuncOp with the desired semantics to
include the case where the visibility is undefined, which implicitly means
`public` visibility.

PiperOrigin-RevId: 550026464
  • Loading branch information
arfaian authored and tensorflower-gardener committed Jul 21, 2023
1 parent 7cab1c1 commit 3db0647
Showing 1 changed file with 39 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,27 @@ class RemoveCustomCallWithSharding
}
};

namespace {

// Reports whether `func_op` qualifies as the StableHLO entry ("main") function.
//
// A function qualifies when all of the following hold:
//   * it is non-null;
//   * its symbol name contains kStablehloModuleDefaultEntryFuncName;
//   * its `sym_visibility` is neither "nested" nor "private".
//
// A function with no `sym_visibility` attribute at all is accepted: an
// undefined visibility implicitly means `public`, and treating it otherwise
// was the source of the crash this helper fixes.
//
// NOTE(review): the name check is a substring match (`contains`), not an
// exact match, so any symbol whose name merely includes the entry name also
// passes. Confirm this is intended (e.g. to tolerate renamed/suffixed clones).
bool IsShloMainFuncOp(mlir::func::FuncOp func_op) {
  if (func_op == nullptr) return false;

  const bool has_entry_name =
      func_op.getSymName().contains(kStablehloModuleDefaultEntryFuncName);
  if (!has_entry_name) return false;

  // An absent visibility compares unequal to both strings, so it falls
  // through as "not hidden", i.e. public.
  const auto visibility = func_op.getSymVisibility();
  const bool is_hidden = visibility == "nested" || visibility == "private";
  return !is_hidden;
}

}  // namespace

class ConvertTFXlaCallModuleOp
: public mlir::OpRewritePattern<mlir::TF::XlaCallModuleOp> {
public:
Expand All @@ -87,10 +108,23 @@ class ConvertTFXlaCallModuleOp
}
SymbolTable parent_module_symbol_table(module_op_);
SymbolTable stablehlo_module_symbol_table(stablehlo_module_op.get());
if (stablehlo_module_symbol_table.lookup<mlir::func::FuncOp>(
kStablehloModuleDefaultEntryFuncName) == nullptr) {
return rewriter.notifyMatchFailure(
op, "could not find main function in XlaCallModuleOp");
{
auto main_func_op =
stablehlo_module_symbol_table.lookup<mlir::func::FuncOp>(
kStablehloModuleDefaultEntryFuncName);
// TODO(b/291988976): move enforcement of this variable outside of this
// rewrite pattern such that it's only checked once. Currently, this
// approach results in duplicate error messages as this pattern executes
// more than once.
if (!IsShloMainFuncOp(main_func_op)) {
auto error_msg =
"'main' FuncOp in XlaCallModuleOp missing or has visibility other "
"than 'public'";
if (main_func_op) {
main_func_op->emitError(error_msg);
}
return rewriter.notifyMatchFailure(op, error_msg);
}
}
mlir::Builder stablehlo_builder(stablehlo_module_op.get().getContext());
// Rename XlaCallModuleOp's functions to avoid naming conflicts.
Expand All @@ -111,9 +145,7 @@ class ConvertTFXlaCallModuleOp
for (auto func_op :
stablehlo_module_op.get().getOps<mlir::func::FuncOp>()) {
mlir::func::FuncOp cloned_func_op = func_op.clone();
if (cloned_func_op.getSymName().contains(
kStablehloModuleDefaultEntryFuncName) &&
cloned_func_op.getSymVisibility() == "public") {
if (IsShloMainFuncOp(cloned_func_op)) {
main_fn = cloned_func_op;
}
cloned_func_op.setSymVisibility(
Expand Down

0 comments on commit 3db0647

Please sign in to comment.