[go: up, home]

Skip to content

Commit

Permalink
Reverts changelist 641306427
Browse files — browse the repository at this point in the history
PiperOrigin-RevId: 642451001
  • Loading branch information
tensorflower-gardener committed Jun 12, 2024
1 parent c9e4c40; commit e3ac39a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 30 deletions.
2 changes: 1 addition & 1 deletion tensorflow/lite/delegates/gpu/common/model_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1547,7 +1547,7 @@ TEST(CastOperationParserTest, TestIsSupported) {
->IsSupported(context.get(), context->node(), context->registration())
.ok());

context->tensor(1)->type = kTfLiteInt64;
context->tensor(1)->type = kTfLiteInt8;
context->tensor(2)->type = kTfLiteFloat32;
EXPECT_FALSE(
parser
Expand Down
1 change: 0 additions & 1 deletion tensorflow/lite/tools/versioning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ cc_library(
":op_signature",
"//tensorflow/lite:builtin_op_data",
"//tensorflow/lite:builtin_ops",
"//tensorflow/lite/core/c:c_api_types",
"//tensorflow/lite/core/c:common",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/status",
Expand Down
43 changes: 15 additions & 28 deletions tensorflow/lite/tools/versioning/gpu_compatibility.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/core/c/c_api_types.h"
#include "tensorflow/lite/tools/versioning/op_signature.h"

namespace tflite {
Expand Down Expand Up @@ -517,41 +516,29 @@ absl::Status CheckGpuDelegateCompatibility(const OpSignature& op_sig) {
return absl::OkStatus();
}

case kTfLiteBuiltinCast: {
case kTfLiteBuiltinCast:
RETURN_IF_ERROR(CheckInputsOutputs(op_sig,
/*required_runtime_inputs=*/1,
/*required_outputs=*/1));
bool input_type_is_supported = false;
bool output_type_is_supported = false;
if (op_sig.inputs.at(0).type == kTfLiteBool ||
op_sig.inputs.at(0).type == kTfLiteFloat32 ||
op_sig.inputs.at(0).type == kTfLiteInt8 ||
op_sig.inputs.at(0).type == kTfLiteUInt8 ||
op_sig.inputs.at(0).type == kTfLiteInt16 ||
op_sig.inputs.at(0).type == kTfLiteUInt16 ||
op_sig.inputs.at(0).type == kTfLiteInt32 ||
op_sig.inputs.at(0).type == kTfLiteUInt32) {
input_type_is_supported = true;
}
if (op_sig.outputs.at(0).type == kTfLiteBool ||
op_sig.outputs.at(0).type == kTfLiteFloat32 ||
op_sig.outputs.at(0).type == kTfLiteInt8 ||
op_sig.outputs.at(0).type == kTfLiteUInt8 ||
op_sig.outputs.at(0).type == kTfLiteInt16 ||
op_sig.outputs.at(0).type == kTfLiteUInt16 ||
op_sig.outputs.at(0).type == kTfLiteInt32 ||
op_sig.outputs.at(0).type == kTfLiteUInt32) {
output_type_is_supported = true;
}

if (input_type_is_supported && output_type_is_supported) {
if (op_sig.inputs.at(0).type == kTfLiteBool &&
(op_sig.outputs.at(0).type == kTfLiteFloat16 ||
op_sig.outputs.at(0).type == kTfLiteFloat32)) {
return absl::OkStatus();
}
} else if ((op_sig.inputs.at(0).type == kTfLiteFloat16 ||
op_sig.inputs.at(0).type == kTfLiteFloat32) &&
op_sig.outputs.at(0).type == kTfLiteBool) {
return absl::OkStatus();
} else if ((op_sig.inputs.at(0).type == kTfLiteFloat32 ||
op_sig.inputs.at(0).type == kTfLiteInt32) &&
(op_sig.outputs.at(0).type == kTfLiteFloat32 ||
op_sig.outputs.at(0).type == kTfLiteInt32)) {
return absl::OkStatus();
} else {
return absl::UnimplementedError(absl::StrCat(
"Not supported Cast case. Input type: ",
TfLiteTypeGetName(op_sig.inputs.at(0).type), " and output type: ",
TfLiteTypeGetName(op_sig.outputs.at(0).type)));
}
}

case kTfLiteBuiltinConcatenation: {
const TfLiteConcatenationParams* tf_options;
Expand Down

0 comments on commit e3ac39a

Please sign in to comment.