diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index 3f4aadae425640..6a70bda929e35c 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -986,6 +986,7 @@ cc_library(
     name = "tpu_embedding_ops",
     srcs = ["tpu_embedding_ops.cc"],
     deps = [
+        "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla:xla_context",
         "//tensorflow/compiler/tf2xla:xla_helpers",
diff --git a/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc b/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc
index 364c1eefb9d005..f23215c086b513 100644
--- a/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc
+++ b/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "absl/cleanup/cleanup.h"
 #include "absl/log/log.h"
 #include "absl/status/status.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "xla/client/xla_builder.h"
@@ -191,11 +192,11 @@ void CompileSendTPUEmbeddingGradients(
   std::vector<xla::Shape> gradient_shapes;
   auto builder = ctx->builder();
   gradient_shapes.reserve(gradients.size());
-  for (xla::XlaOp op : gradients) {
-    // Gradient layout information is added by XLA, so we can just create
-    // default layout information.
-    xla::Shape gradient_shape = builder->GetShape(op).value();
-    xla::LayoutUtil::SetToDefaultLayout(&gradient_shape);
+  for (int i = 0; i < gradients.size(); ++i) {
+    DataType dtype = ctx->input_type(i);
+    xla::Shape gradient_shape;
+    OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, tf_gradient_shapes[i],
+                                              &gradient_shape));
     gradient_shapes.push_back(gradient_shape);
   }