[go: nahoru, domu]

Skip to content

Commit

Permalink
Replace tsl errors with absl errors.
Browse files (browse the repository at this point in the history)
This improves source location reporting (see go/drop-tensorflow-status).

PiperOrigin-RevId: 534693287
  • Loading branch information
gnecula authored and tensorflower-gardener committed May 24, 2023
1 parent c61d04f commit a935a88
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 47 deletions.
3 changes: 3 additions & 0 deletions tensorflow/compiler/tf2xla/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ cc_library(
"//tensorflow/tsl/platform:errors",
"//tensorflow/tsl/platform:regexp",
"//tensorflow/tsl/platform:statusor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
Expand Down Expand Up @@ -433,6 +435,7 @@ tf_kernel_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:status",
"//tensorflow/core/tpu:tpu_defs",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
],
Expand Down
78 changes: 40 additions & 38 deletions tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
Expand Down Expand Up @@ -94,27 +96,27 @@ tsl::StatusOr<mlir::Value> ComputeDimensionValue(
int arg_idx, arg_axis_idx;
if (!RE2::FullMatch(dim_arg_spec, *dim_arg_spec_re, &arg_idx,
&arg_axis_idx)) {
return tsl::errors::InvalidArgument("Syntax error in dim_args_spec '",
dim_arg_spec, "'");
return absl::InvalidArgumentError(
absl::StrCat("Syntax error in dim_args_spec '", dim_arg_spec, "'"));
}
if (arg_idx < 0 || arg_idx >= arguments.size()) {
return tsl::errors::InvalidArgument(
return absl::InvalidArgumentError(absl::StrCat(
"Invalid argument index ", arg_idx,
" when the number of non-dimension arguments is ", arguments.size(),
" in dim_arg_spec '", dim_arg_spec, "'");
" in dim_arg_spec '", dim_arg_spec, "'"));
}
mlir::RankedTensorType arg_type =
arguments[arg_idx].getType().dyn_cast<mlir::RankedTensorType>();
if (!arg_type) {
return tsl::errors::InvalidArgument(
"Argument ", arg_idx, " referenced in dim_arg_spec '", dim_arg_spec,
"' does not have a RankedTensorType");
return absl::InvalidArgumentError(
absl::StrCat("Argument ", arg_idx, " referenced in dim_arg_spec '",
dim_arg_spec, "' does not have a RankedTensorType"));
}
if (arg_axis_idx < 0 || arg_axis_idx >= arg_type.getShape().size()) {
return tsl::errors::InvalidArgument(
return absl::InvalidArgumentError(absl::StrCat(
"Invalid axis index ", arg_axis_idx,
" when the rank of non-dimension argument ", arg_idx, " is ",
arg_type.getShape().size(), " in dim_arg_spec '", dim_arg_spec, "'");
arg_type.getShape().size(), " in dim_arg_spec '", dim_arg_spec, "'"));
}
mlir::Value val;
mlir::Type get_dim_type =
Expand All @@ -135,15 +137,15 @@ tsl::StatusOr<std::unique_ptr<XlaCallModuleLoader>> XlaCallModuleLoader::Create(
mlir::MLIRContext *context, int version, std::string module_str,
std::vector<std::string> dim_args_spec, int platform_index) {
if (version < VERSION_MINIMUM_SUPPORTED) {
return tsl::errors::InvalidArgument(
return absl::InvalidArgumentError(absl::StrCat(
"XlaCallModuleOp with version ", version,
" is not supported anymore. Must be >= ", VERSION_MINIMUM_SUPPORTED);
" is not supported anymore. Must be >= ", VERSION_MINIMUM_SUPPORTED));
}
if (version > VERSION_MAXIMUM_SUPPORTED) {
return tsl::errors::InvalidArgument(
"XlaCallModuleOp with version ", version,
" is not supported by this build. Must be <= ",
VERSION_MAXIMUM_SUPPORTED);
return absl::InvalidArgumentError(
absl::StrCat("XlaCallModuleOp with version ", version,
" is not supported by this build. Must be <= ",
VERSION_MAXIMUM_SUPPORTED));
}

if (version < VERSION_START_PLATFORMS) {
Expand Down Expand Up @@ -204,18 +206,18 @@ tsl::Status XlaCallModuleLoader::AddMainWrapper() {
mlir::func::FuncOp orig_main =
module_->lookupSymbol<mlir::func::FuncOp>("main");
if (!orig_main) {
return tsl::errors::InvalidArgument("Cannot find 'main' in module");
return absl::InvalidArgumentError("Cannot find 'main' in module");
}
int nr_platform_args = 0;
if (platform_index_ >= 0) {
nr_platform_args = 1;
}
if (orig_main.getNumArguments() <= nr_platform_args + nr_dim_args) {
return tsl::errors::InvalidArgument(
"The module should have ", nr_platform_args,
" platform index arguments and ", nr_dim_args,
" dimension arguments, but it ", "has only ",
orig_main.getNumArguments(), " total arguments");
return absl::InvalidArgumentError(
absl::StrCat("The module should have ", nr_platform_args,
" platform index arguments and ", nr_dim_args,
" dimension arguments, but it ", "has only ",
orig_main.getNumArguments(), " total arguments"));
}
mlir::Block &orig_main_body = orig_main.front();

Expand Down Expand Up @@ -250,18 +252,18 @@ tsl::Status XlaCallModuleLoader::AddMainWrapper() {
!arg_ranked_type.getShape().empty()) {
std::string argument_type =
(i < nr_platform_args) ? "platform index" : "dimension";
return tsl::errors::InvalidArgument(
return absl::InvalidArgumentError(absl::StrCat(
"Module argument at index ", i,
" should be a 0-dimensional integer-tensor ", argument_type,
" argument but has type ", mlir::debugString(arg_type));
" argument but has type ", mlir::debugString(arg_type)));
}
if (i < nr_platform_args) {
if (arg_ranked_type.getElementTypeBitWidth() != 32) {
return tsl::errors::InvalidArgument(
"Module argument at index ", i,
" should be a 0-dimensional 32-bit integer-tensor"
" platform index argument but has type ",
mlir::debugString(arg_type));
return absl::InvalidArgumentError(
absl::StrCat("Module argument at index ", i,
" should be a 0-dimensional 32-bit integer-tensor"
" platform index argument but has type ",
mlir::debugString(arg_type)));
}
call_args[i] = op_builder.create<mlir::stablehlo::ConstantOp>(
block_args[0].getLoc(),
Expand Down Expand Up @@ -296,14 +298,14 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes(
int nr_dim_args = dim_args_spec_.size();
int non_dimension_arguments = input_shapes.size();
if (non_dimension_arguments != main_body.getNumArguments()) {
return tsl::errors::InvalidArgument(
return absl::InvalidArgumentError(absl::StrCat(
"Incorrect number of arguments passed to XlaCallModule: ",
non_dimension_arguments, ". The module takes ",
main_body.getNumArguments() + nr_platform_args + nr_dim_args,
" arguments of which ", nr_platform_args,
" platform index arguments and ", nr_dim_args,
" dimension arguments. It must be called with ",
main_body.getNumArguments(), " arguments.");
main_body.getNumArguments(), " arguments."));
}

mlir::Builder builder(module_->getContext());
Expand Down Expand Up @@ -362,7 +364,7 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes(
XLA_VLOG_LINES(3,
absl::StrCat("XlaCallModule module with verification failed",
ModuleToString(*module_, VLOG_IS_ON(4))));
return tsl::errors::InvalidArgument("Module inlining failed");
return absl::InvalidArgumentError("Module inlining failed");
}
XLA_VLOG_LINES(
5, absl::StrCat(
Expand Down Expand Up @@ -395,7 +397,7 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes(
XLA_VLOG_LINES(
3, absl::StrCat("XlaCallModule module with verification failed: ",
ModuleToString(*module_, VLOG_IS_ON(4))));
return tsl::errors::InvalidArgument("Module verification failed");
return absl::InvalidArgumentError("Module verification failed");
}
mlir::PassManager pm(module_->getContext());
if (VLOG_IS_ON(5)) {
Expand All @@ -412,7 +414,7 @@ tsl::Status XlaCallModuleLoader::RefineDynamicShapes(
XLA_VLOG_LINES(
3, absl::StrCat("XlaCallModule module with verification failed: ",
ModuleToString(*module_, VLOG_IS_ON(4))));
return tsl::errors::InvalidArgument("Module shape refinement failed");
return absl::InvalidArgumentError("Module shape refinement failed");
}

XLA_VLOG_LINES(3, absl::StrCat("XlaCallModule module with refined shapes: ",
Expand Down Expand Up @@ -444,7 +446,7 @@ tsl::Status XlaCallModuleLoader::LoadAndPreprocessModule(
}

if (!module_) {
return tsl::errors::InvalidArgument("Cannot deserialize computation");
return absl::InvalidArgumentError("Cannot deserialize computation");
}
XLA_VLOG_LINES(3, absl::StrCat("Parsed serialized module (version ", version,
", platform_index = ", platform_index_,
Expand All @@ -456,11 +458,11 @@ tsl::Status XlaCallModuleLoader::LoadAndPreprocessModule(
XLA_VLOG_LINES(
3, absl::StrCat("XlaCallModule module with verification failed: ",
ModuleToString(*module_, VLOG_IS_ON(4))));
return tsl::errors::InvalidArgument("Error verifying module");
return absl::InvalidArgumentError("Error verifying module");
}
main_ = module_->lookupSymbol<mlir::func::FuncOp>("main");
if (!main_) {
return tsl::errors::InvalidArgument("Cannot find 'main' in module");
return absl::InvalidArgumentError("Cannot find 'main' in module");
}

if (!dim_args_spec_.empty() || platform_index_ >= 0) {
Expand Down Expand Up @@ -504,9 +506,9 @@ tsl::Status XlaCallModuleLoader::ValidateModule() {
});

if (moduleHasUnsupportedDialects)
return tsl::errors::InvalidArgument("Module has unsupported dialects");
return absl::InvalidArgumentError("Module has unsupported dialects");
if (moduleHasDynamicShapes)
return tsl::errors::InvalidArgument("Module has dynamic shapes");
return absl::InvalidArgumentError("Module has dynamic shapes");
return tsl::OkStatus();
}

Expand Down
20 changes: 11 additions & 9 deletions tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "llvm/ADT/ArrayRef.h"
#include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h"
Expand Down Expand Up @@ -51,10 +53,10 @@ class XlaCallModuleOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dim_args_spec", &dim_args_spec));
OP_REQUIRES(ctx,
expected_output_shapes.size() == expected_output_dtypes.size(),
errors::InvalidArgument("The size of Sout (",
expected_output_shapes.size(),
") must match the size of Tout (",
expected_output_dtypes.size(), ")"));
absl::InvalidArgumentError(absl::StrCat(
"The size of Sout (", expected_output_shapes.size(),
") must match the size of Tout (",
expected_output_dtypes.size(), ")")));
std::vector<string> platforms;
// Index in platforms of the current platform, or -1 if module does not take
// a platform index arg.
Expand All @@ -73,23 +75,23 @@ class XlaCallModuleOp : public XlaOpKernel {
current_platform = "ROCM";
#else
OP_REQUIRES(ctx, false,
errors::Unimplemented("CUDA or ROCM build required"));
absl::UnimplementedError("CUDA or ROCM build required"));
#endif
} else if (current_device_type == DEVICE_TPU_XLA_JIT) {
current_platform = "TPU";
} else {
OP_REQUIRES(ctx, false,
errors::Unimplemented("Unexpected device type ",
current_device_type));
absl::UnimplementedError(absl::StrCat(
"Unexpected device type ", current_device_type)));
}
VLOG(3) << "Initialized XlaCallModuleOp on " << current_platform;
auto found_platform =
std::find(platforms.begin(), platforms.end(), current_platform);
OP_REQUIRES(ctx, found_platform != platforms.end(),
errors::NotFound(
absl::NotFoundError(absl::StrCat(
"The current platform ", current_platform,
" is not among the platforms required by the module: [",
absl::StrJoin(platforms, ", "), "]"));
absl::StrJoin(platforms, ", "), "]")));
// We only use a platform index arguments if we support at least 2
// platforms.
if (platforms.size() > 1) {
Expand Down

0 comments on commit a935a88

Please sign in to comment.