[go: nahoru, domu]

Skip to content

Commit

Permalink
Add XlaHostComputeOp to the allowed list in mark_ops_for_outside_co…
Browse files Browse the repository at this point in the history
…mpilation.cc

PiperOrigin-RevId: 627746297
  • Loading branch information
deqiangc authored and tensorflower-gardener committed Apr 24, 2024
1 parent 51a4843 commit dadbf6d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -595,3 +595,17 @@ func.func @unsupported_op_gpu_cluster() -> tensor<i32> {
}) {allow_soft_placement = true, _xla_compile_device_type = "GPU"} : () -> tensor<i32>
func.return %0 : tensor<i32>
}

// CHECK-LABEL: func @xla_host_compute
// Verifies that tf.XlaHostCompute is on the outside-compilation allowed list:
// the pass must leave it in place (with its attributes intact) rather than
// wrapping it for outside compilation.
func.func @xla_host_compute(%arg0: tensor<i32>) {
  "tf_device.cluster"() ({
    %cst = "tf.Const"() {value = dense<16> : tensor<i32>} : () -> tensor<i32>
    // CHECK: tf.XlaHostCompute
    // CHECK-SAME:_xla_original_oc_node_name = "hcb0", _xla_token_input_nodes = ["_xla_token_arg_node"]
    "tf.XlaHostCompute"(%cst) <{ancestors = [], cost_estimate_ns = 1000000 : i64, key = "_host_callback", recv_key = "", send_key = "", shapes = [], tpu_core = 0 : i64}> {_xla_original_oc_node_name = "hcb0", _xla_token_input_nodes = ["_xla_token_arg_node"]} : (tensor<i32>) -> ()
    tf_device.return
  // Cluster result type must match the (empty) tf_device.return operand list;
  // `() -> tensor<i32>` would fail MLIR verification since nothing is yielded.
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
  func.return
}


Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ void AddSupportedFunctionalOps(MLIRContext* context,
OperationName(mlir::TF::WhileRegionOp::getOperationName(), context));
supported_ops->insert(
OperationName(mlir::TF::XlaCallModuleOp::getOperationName(), context));
supported_ops->insert(
OperationName(mlir::TF::XlaHostComputeOp::getOperationName(), context));
supported_ops->insert(
OperationName(mlir::TF::XlaReduceOp::getOperationName(), context));
supported_ops->insert(
Expand Down

0 comments on commit dadbf6d

Please sign in to comment.