[go: nahoru, domu]

Skip to content

Commit

Permalink
Add disabled_checks attribute to XlaCallModule.
Browse files Browse the repository at this point in the history
This attribute encodes the disabled_checks option that was passed at
serialization time, and also controls safety checks when the module
is loaded and compiled. See docstring in xla_ops.cc.

Also bump the version number of XlaCallModule to 6.

From this version on, the XlaCallModule op must have a non-empty `platforms`
attribute that encodes the platforms for which the module was serialized.
Previously, front-ends could disable platform checking
by setting an empty `platforms` attribute.

We have not yet added a way to specify additional disabled checks via an environment variable.

PiperOrigin-RevId: 538421669
  • Loading branch information
gnecula authored and tensorflower-gardener committed Jun 7, 2023
1 parent 4993c55 commit d592d80
Show file tree
Hide file tree
Showing 11 changed files with 256 additions and 113 deletions.
3 changes: 2 additions & 1 deletion tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -21205,7 +21205,8 @@ platform argument (see `platforms`) nor the dimension arguments (see
DefaultValuedOptionalAttr<StrArrayAttr, "{}">:$dim_args_spec,
DefaultValuedOptionalAttr<StrArrayAttr, "{}">:$platforms,
DefaultValuedOptionalAttr<TF_SymbolRefArrayAttr, "{}">:$function_list,
DefaultValuedOptionalAttr<BoolAttr, "false">:$has_token_input_output
DefaultValuedOptionalAttr<BoolAttr, "false">:$has_token_input_output,
DefaultValuedOptionalAttr<StrArrayAttr, "{}">:$disabled_checks
);

let results = (outs
Expand Down
16 changes: 13 additions & 3 deletions tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1200,14 +1200,24 @@ bool ShapeInference::InferShapeForXlaCallModule(XlaCallModuleOp op) {
for (auto attr : op.getDimArgsSpec().getAsRange<StringAttr>()) {
dim_args_spec.push_back(attr.getValue().str());
}

std::vector<std::string> disabled_checks;
for (auto attr : op.getDisabledChecks().getAsRange<StringAttr>()) {
disabled_checks.push_back(attr.getValue().str());
}
std::vector<std::string> platforms;
for (auto attr : op.getPlatforms().getAsRange<StringAttr>()) {
platforms.push_back(attr.getValue().str());
}
// Always use the first platform. The assumption is that shape inference
// results should be the same regardless of which platform is chosen.
int platform_index = op.getPlatforms().size() > 1 ? 0 : -1;
// Very old versions of the op have an empty platforms attribute.
std::string loading_platform =
(platforms.empty() ? "CPU" : platforms.front());

auto l = tensorflow::XlaCallModuleLoader::Create(
&xla_call_module_context_, op.getVersion(), op.getModule().str(),
std::move(dim_args_spec), platform_index);
std::move(dim_args_spec), std::move(disabled_checks),
std::move(platforms), std::move(loading_platform));
if (!l.ok()) {
LLVM_DEBUG(llvm::dbgs() << "Parsing error in XlaCallModule: "
<< l.status().ToString() << "\n");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,25 @@ tsl::StatusOr<OwningOpRef<ModuleOp>> DeserializeStablehlo(MLIRContext *context,
for (auto attr : op.getDimArgsSpec().getAsRange<StringAttr>()) {
dim_args_spec.push_back(attr.getValue().str());
}
std::vector<std::string> disabled_checks;
for (auto attr : op.getDisabledChecks().getAsRange<StringAttr>()) {
disabled_checks.push_back(attr.getValue().str());
}
std::vector<std::string> platforms;
for (auto attr : op.getPlatforms().getAsRange<StringAttr>()) {
platforms.push_back(attr.getValue().str());
}
// XlaCallModuleOp OpKernel will determine platform index when running
// TF2XLA. We don't know the device/platform type in this MLIR pass, so
// we set platform_index to -1.
TF_ASSIGN_OR_RETURN(auto loader,
tensorflow::XlaCallModuleLoader::Create(
context, static_cast<int>(op.getVersion()),
op.getModule().str(), dim_args_spec,
/*platform_index=*/-1));
// we set loading_platform to the first platform.
std::string loading_platform =
(platforms.empty() ? "CPU" : platforms.front());
TF_ASSIGN_OR_RETURN(
auto loader,
tensorflow::XlaCallModuleLoader::Create(
context, static_cast<int>(op.getVersion()), op.getModule().str(),
std::move(dim_args_spec), std::move(disabled_checks),
std::move(platforms), std::move(loading_platform)));
return std::move(*loader).module();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
// }
// }
// CHECK: call @main.2
%0 = "tf.XlaCallModule"(%arg0) {Sout = [#tf_type.shape<*>], device = "", dim_args_spec = [], function_list = [], has_token_input_output = false, module = "ML\EFR\03MLIRxxx-trunk\00\01\17\05\01\05\01\03\05\03\07\07\t\0B\03K5\07\01\1B\07\0B\13\0B3\0B\0B\0B\0B\0F\0B\13\0B\03\1B\0F\1B\0B\0B\0B\0B\0B\0F\13\0B\0B\0B\0B\03\07\0F\17\07\02\A7\1F\05\0D\03\03\03\07\05\0F\03\0B\0B\1B\0D'\0F)\031\113\05\11\05\13\05\15\05\17\1D\15\17\05\19\17\19\EF\01\05\1B\03\03\1D\0D\05\1F!#%\1D\1D\1D\1F\1D!\1D##\03\03\03+\0D\03-/\1D%\1D'\1D)\1D+)\01\05\11\03\01\03\01\t\04A\05\01\11\01\05\07\03\01\05\03\11\01\t\05\03\05\0B\03\01\01\05\06\13\03\01\03\01\07\04\01\03\03\06\03\01\05\01\00\9A\04-\0F\0B\03!\1B\1D\05\1B\83/\1F\15\1D\15\11\13\15\11\11\0F\0B\11builtin\00vhlo\00module\00func_v1\00sine_v1\00return_v1\00sym_name\00jit_sin\00arg_attrs\00function_type\00res_attrs\00sym_visibility\00jit(sin)/jit(main)/sin\00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\00jax.arg_info\00x\00mhlo.sharding\00{replicated}\00jax.result_info\00\00main\00public\00", platforms = [], version = 4 : i64} : (tensor<f32>) -> tensor<*xf32>
%0 = "tf.XlaCallModule"(%arg0) {Sout = [#tf_type.shape<*>], device = "", dim_args_spec = [], function_list = [], disabled_checks = [], has_token_input_output = false, module = "ML\EFR\03MLIRxxx-trunk\00\01\17\05\01\05\01\03\05\03\07\07\t\0B\03K5\07\01\1B\07\0B\13\0B3\0B\0B\0B\0B\0F\0B\13\0B\03\1B\0F\1B\0B\0B\0B\0B\0B\0F\13\0B\0B\0B\0B\03\07\0F\17\07\02\A7\1F\05\0D\03\03\03\07\05\0F\03\0B\0B\1B\0D'\0F)\031\113\05\11\05\13\05\15\05\17\1D\15\17\05\19\17\19\EF\01\05\1B\03\03\1D\0D\05\1F!#%\1D\1D\1D\1F\1D!\1D##\03\03\03+\0D\03-/\1D%\1D'\1D)\1D+)\01\05\11\03\01\03\01\t\04A\05\01\11\01\05\07\03\01\05\03\11\01\t\05\03\05\0B\03\01\01\05\06\13\03\01\03\01\07\04\01\03\03\06\03\01\05\01\00\9A\04-\0F\0B\03!\1B\1D\05\1B\83/\1F\15\1D\15\11\13\15\11\11\0F\0B\11builtin\00vhlo\00module\00func_v1\00sine_v1\00return_v1\00sym_name\00jit_sin\00arg_attrs\00function_type\00res_attrs\00sym_visibility\00jit(sin)/jit(main)/sin\00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\00jax.arg_info\00x\00mhlo.sharding\00{replicated}\00jax.result_info\00\00main\00public\00", platforms = ["CPU"], version = 6 : i64} : (tensor<f32>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ func.func @xla_call_module(%arg0: tensor<f32>) -> tensor<*xf32> {
// }
// }
// expected-remark@+1 {{UNIMPLEMENTED: MlirHloBuilder does not support op call}}
%0 = "tf.XlaCallModule"(%arg0) {Sout = [#tf_type.shape<*>], device = "", dim_args_spec = [], function_list = [], has_token_input_output = false, module = "ML\EFR\03MLIRxxx-trunk\00\01\17\05\01\05\01\03\05\03\07\07\t\0B\03K5\07\01\1B\07\0B\13\0B3\0B\0B\0B\0B\0F\0B\13\0B\03\1B\0F\1B\0B\0B\0B\0B\0B\0F\13\0B\0B\0B\0B\03\07\0F\17\07\02\A7\1F\05\0D\03\03\03\07\05\0F\03\0B\0B\1B\0D'\0F)\031\113\05\11\05\13\05\15\05\17\1D\15\17\05\19\17\19\EF\01\05\1B\03\03\1D\0D\05\1F!#%\1D\1D\1D\1F\1D!\1D##\03\03\03+\0D\03-/\1D%\1D'\1D)\1D+)\01\05\11\03\01\03\01\t\04A\05\01\11\01\05\07\03\01\05\03\11\01\t\05\03\05\0B\03\01\01\05\06\13\03\01\03\01\07\04\01\03\03\06\03\01\05\01\00\9A\04-\0F\0B\03!\1B\1D\05\1B\83/\1F\15\1D\15\11\13\15\11\11\0F\0B\11builtin\00vhlo\00module\00func_v1\00sine_v1\00return_v1\00sym_name\00jit_sin\00arg_attrs\00function_type\00res_attrs\00sym_visibility\00jit(sin)/jit(main)/sin\00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\00jax.arg_info\00x\00mhlo.sharding\00{replicated}\00jax.result_info\00\00main\00public\00", platforms = [], version = 4 : i64} : (tensor<f32>) -> tensor<*xf32>
%0 = "tf.XlaCallModule"(%arg0) {Sout = [#tf_type.shape<*>], device = "", dim_args_spec = [], function_list = [], disabled_checks = [], has_token_input_output = false, module = "ML\EFR\03MLIRxxx-trunk\00\01\17\05\01\05\01\03\05\03\07\07\t\0B\03K5\07\01\1B\07\0B\13\0B3\0B\0B\0B\0B\0F\0B\13\0B\03\1B\0F\1B\0B\0B\0B\0B\0B\0F\13\0B\0B\0B\0B\03\07\0F\17\07\02\A7\1F\05\0D\03\03\03\07\05\0F\03\0B\0B\1B\0D'\0F)\031\113\05\11\05\13\05\15\05\17\1D\15\17\05\19\17\19\EF\01\05\1B\03\03\1D\0D\05\1F!#%\1D\1D\1D\1F\1D!\1D##\03\03\03+\0D\03-/\1D%\1D'\1D)\1D+)\01\05\11\03\01\03\01\t\04A\05\01\11\01\05\07\03\01\05\03\11\01\t\05\03\05\0B\03\01\01\05\06\13\03\01\03\01\07\04\01\03\03\06\03\01\05\01\00\9A\04-\0F\0B\03!\1B\1D\05\1B\83/\1F\15\1D\15\11\13\15\11\11\0F\0B\11builtin\00vhlo\00module\00func_v1\00sine_v1\00return_v1\00sym_name\00jit_sin\00arg_attrs\00function_type\00res_attrs\00sym_visibility\00jit(sin)/jit(main)/sin\00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\00jax.arg_info\00x\00mhlo.sharding\00{replicated}\00jax.result_info\00\00main\00public\00", platforms = ["CPU"], version = 6 : i64} : (tensor<f32>) -> tensor<*xf32>
func.return %0 : tensor<*xf32>
}

Expand Down
Loading

0 comments on commit d592d80

Please sign in to comment.