[go: nahoru, domu]

Skip to content

Commit

Permalink
[XlaCallModule] Fixes for serialization version 9.
Browse files Browse the repository at this point in the history
In version 9, the main function of a serialized module may contain token arguments and outputs. Those do not correspond to actual XlaCallModule op inputs and outputs.

In cl/577032011 we had adjusted the input_shapes for the call to RefineDynamicShapes from xla_call_module_op. Here we move that adjustment of input_shapes into RefineDynamicShapes itself, so that it takes effect for all call sites, including those from shape_inference.

PiperOrigin-RevId: 597445704
  • Loading branch information
gnecula authored and tensorflower-gardener committed Jan 11, 2024
1 parent 1e2f8e8 commit 92d5caa
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 31 deletions.
41 changes: 28 additions & 13 deletions tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
Expand Down Expand Up @@ -1235,8 +1236,8 @@ bool ShapeInference::InferShapeForXlaCallModule(XlaCallModuleOp op) {
/*num_invocation_args=*/op.getArgs().size(),
op.getHasTokenInputOutput());
if (!l.ok()) {
LLVM_DEBUG(llvm::dbgs() << "Parsing error in XlaCallModule: "
<< l.status().ToString() << "\n");
llvm::errs() << "Parsing error in XlaCallModule: "
<< l.status().ToString() << "\n";
return false;
}

Expand All @@ -1255,30 +1256,44 @@ bool ShapeInference::InferShapeForXlaCallModule(XlaCallModuleOp op) {

tsl::Status status = loader->RefineDynamicShapes(input_shapes);
if (!status.ok()) {
LLVM_DEBUG(llvm::dbgs() << "Failed during XlaCallModule shape refinement: "
<< status.ToString());
llvm::errs() << "Failed during XlaCallModule shape refinement: "
<< status.ToString();
// RefineDynamicShapes returns ok only when it produces full static shapes.
// It may partially succeed by producing RankedTensor shapes with dynamic
// dimensions. Such info is still useful for the downstream. We don't need
// to abort here.
// TODO(b/316639984): improve RefineDynamicShapes return values to include
// these info.
return false;
}
mlir::ResultRange op_results = op.getResults();
// The main_outputs may include tokens that are not among the op_results;
mlir::TypeRange main_output_types = loader->OutputTypes();
int nr_main_token_outputs =
llvm::count_if(main_output_types, tensorflow::IsTokenType);
if (op_results.size() != main_output_types.size() - nr_main_token_outputs) {
llvm::errs() << "XlaCallModule has " << op_results.size()
<< " but the main function has "
<< main_output_types.size() - nr_main_token_outputs
<< " non-token ouputs";
return false;
}

bool changed = false;
for (auto [result, type] :
llvm::zip(op.getResults(), loader->OutputTypes())) {
auto ranked = type.dyn_cast<RankedTensorType>();
if (ranked == nullptr) {
LLVM_DEBUG(llvm::dbgs()
<< "Unsupported XlaCallModule result type: " << type);
continue;
int next_op_result = 0;
for (auto output_type : main_output_types) {
if (tensorflow::IsTokenType(output_type)) continue;
auto output_type_ranked = output_type.dyn_cast<RankedTensorType>();
if (output_type_ranked == nullptr) {
llvm::errs() << "Unsupported XlaCallModule result type: " << output_type
<< "\n";
return false;
}
auto result = op_results[next_op_result++];

// Build a new type object from `type` and `elem_type`. `type` is owned by
// `xla_call_module_context_` and should not be mixed with op's context.
auto new_type = RankedTensorType::get(
ranked.getShape(), getElementTypeOrSelf(result.getType()));
output_type_ranked.getShape(), getElementTypeOrSelf(result.getType()));

changed = RefineResultType(op, result, new_type) || changed;
}
Expand Down
32 changes: 24 additions & 8 deletions tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -186,17 +186,30 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes(
return tsl::OkStatus();
}
}
// Add the tokens to the input_shapes. Starting with version 9, the main
// function may take token arguments that do not correspond with op inputs.
int nr_inputs = NrInputs();
int nr_expected_tokens = llvm::count_if(InputTypes(), IsTokenType);
if (input_shapes.size() != NrInputs() - nr_expected_tokens) {
return absl::InvalidArgumentError(absl::StrCat(
"XlaCallModule RefineDynamicShapes called with ", input_shapes.size(),
"input shapes, but the main function takes ",
NrInputs() - nr_expected_tokens, " non-token arguments"));
}

mlir::Block &main_body = main_.front();
int non_dimension_arguments = input_shapes.size();

mlir::Builder builder(module_->getContext());
std::vector<mlir::Type> static_array_input_types(non_dimension_arguments);
for (int i = 0, end = non_dimension_arguments; i < end; ++i) {
const xla::Shape &xla_shape = input_shapes[i];
if (xla_shape.IsToken()) {
std::vector<mlir::Type> static_array_input_types(nr_inputs);
int next_actual_input = 0;
for (int i = 0, end = nr_inputs; i < end; ++i) {
mlir::Type arg_type = main_body.getArgument(i).getType();
if (IsTokenType(arg_type)) {
static_array_input_types[i] = mlir::stablehlo::TokenType::get(context_);
VLOG(3) << "XlaCallModule static array input type #" << i << ": "
<< mlir::debugString(static_array_input_types[i])
<< " for argument type " << mlir::debugString(arg_type);
} else {
const xla::Shape &xla_shape = input_shapes[next_actual_input++];
std::vector<int64_t> xla_dimensions(xla_shape.dimensions().begin(),
xla_shape.dimensions().end());
TF_ASSIGN_OR_RETURN(
Expand All @@ -209,12 +222,15 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes(
// mlir::Type type,
// ConvertShapeToType<mlir::RankedTensorType>(xla_shape, builder));
VLOG(3) << "XlaCallModule static array input type #" << i << ": "
<< mlir::debugString(type);
<< mlir::debugString(type) << " for argument type "
<< mlir::debugString(arg_type);
mlir::TensorType arg_type =
main_body.getArgument(i).getType().dyn_cast<mlir::TensorType>();
if (arg_type == nullptr) {
return absl::InvalidArgumentError(absl::StrCat(
"Argument ", i, " passed to XlaCallModule is not a tensor"));
"Argument ", i, " passed to XlaCallModule is not a tensor, ",
"has type ",
mlir::debugString(main_body.getArgument(i).getType())));
}

if (arg_type.getElementType() != type.getElementType()) {
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class XlaCallModuleLoader {
// then set them as the types of the function parameters, and run StableHLO
// shape refinement to specialize all dynamic shapes in the StableHLO program
// to static shapes.
// Starting with version 9, the "main" function may accept token arguments.
// The input_shapes includes only the non-token arguments.
//
// This method accepts a list of `llvm::ArrayRef` instead of `mlir::Type`.
// This is to prevent callers from accidentally passing `mlir::Type` owned by
Expand Down
15 changes: 5 additions & 10 deletions tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,15 +234,10 @@ class XlaCallModuleOp : public XlaOpKernel {
xla::XlaBuilder *const b = ctx->builder();

std::vector<xla::Shape> input_shapes;
int next_actual_input = 0;
for (mlir::Type inputType : loader_->InputTypes()) {
if (IsTokenType(inputType)) {
input_shapes.push_back(xla::ShapeUtil::MakeTokenShape());
} else {
auto shape = ctx->InputXlaShape(next_actual_input++);
OP_REQUIRES_OK(ctx, shape.status());
input_shapes.push_back(*std::move(shape));
}
for (int i = 0; i < ctx->num_inputs(); ++i) {
auto shape = ctx->InputXlaShape(i);
OP_REQUIRES_OK(ctx, shape.status());
input_shapes.push_back(*std::move(shape));
}
OP_REQUIRES_OK(ctx, loader_->RefineDynamicShapes(input_shapes));
OP_REQUIRES_OK(ctx, loader_->ValidateStaticShapes());
Expand All @@ -263,7 +258,7 @@ class XlaCallModuleOp : public XlaOpKernel {
}

std::vector<xla::XlaOp> inputs;
next_actual_input = 0;
int next_actual_input = 0;
for (mlir::Type inputType : loader_->InputTypes()) {
if (IsTokenType(inputType)) {
if (token_input.IsUninitialized()) {
Expand Down

0 comments on commit 92d5caa

Please sign in to comment.