[go: nahoru, domu]

Skip to content

Commit

Permalink
Updated ImageClassifier error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
kinarr committed Apr 8, 2022
1 parent 9761582 commit 7884c9c
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 25 deletions.
13 changes: 5 additions & 8 deletions tensorflow_lite_support/python/task/vision/image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def create_from_file(cls, file_path: str) -> "ImageClassifier":
Returns:
`ImageClassifier` object that's created from the model file.
Raises:
status.StatusNotOk if failed to create `ImageClassifier` object from the
provided file such as invalid file.
RuntimeError if failed to create the `ImageClassifier` object from the
provided file, e.g. an invalid file.
"""
# TODO(b/220931229): Raise RuntimeError instead of status.StatusNotOk.
# Need to import the module to catch this error:
Expand All @@ -76,8 +76,8 @@ def create_from_options(cls,
Returns:
`ImageClassifier` object that's created from `options`.
Raises:
status.StatusNotOk if failed to create `ImageClassifier` object from
`ImageClassifierOptions` such as missing the model.
RuntimeError if failed to create the `ImageClassifier` object from
`ImageClassifierOptions`, e.g. when the model is missing.
"""
# TODO(b/220931229): Raise RuntimeError instead of status.StatusNotOk.
# Need to import the module to catch this error:
Expand All @@ -103,10 +103,7 @@ def classify(
Returns:
classification result.
Raises:
status.StatusNotOk if failed to get the feature vector. Need to import the
module to catch this error: `from pybind11_abseil
import status`, see
https://github.com/pybind/pybind11_abseil#abslstatusor.
RuntimeError if failed to get the classification result.
"""
image_data = image_utils.ImageData(image.buffer)
if bounding_box is None:
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_lite_support/python/task/vision/pybinds/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ pybind_extension(
"//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils",
"//tensorflow_lite_support/python/task/core/pybinds:task_utils",
"@pybind11",
"@pybind11_abseil//pybind11_abseil:status_casters",
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
],
)
Expand All @@ -36,6 +35,7 @@ pybind_extension(
deps = [
"//tensorflow_lite_support/cc/port:statusor",
"//tensorflow_lite_support/cc/task/processor/proto:bounding_box_cc_proto",
"//tensorflow_lite_support/cc/task/processor/proto:classifications_cc_proto",
"//tensorflow_lite_support/cc/task/processor/proto:classification_options_cc_proto",
"//tensorflow_lite_support/cc/task/vision:image_classifier",
"//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/processor/proto/bounding_box.pb.h"
#include "tensorflow_lite_support/cc/task/processor/proto/classifications.pb.h"
#include "tensorflow_lite_support/cc/task/processor/proto/classification_options.pb.h"
#include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
Expand Down Expand Up @@ -65,27 +66,43 @@ PYBIND11_MODULE(_pywrap_image_classifier, m) {
options.mutable_class_name_blacklist()->CopyFrom(
classification_options.class_name_denylist());

return ImageClassifier::CreateFromOptions(options);
auto classifier = ImageClassifier::CreateFromOptions(options);
return core::get_value(classifier);
})
.def("classify",
[](ImageClassifier& self, const ImageData& image_data)
-> tflite::support::StatusOr<ClassificationResult> {
ASSIGN_OR_RETURN(std::unique_ptr<FrameBuffer> frame_buffer,
CreateFrameBufferFromImageData(image_data));
return self.Classify(*frame_buffer);
-> processor::ClassificationResult {
auto frame_buffer = CreateFrameBufferFromImageData(image_data);
auto vision_classification_result = self.Classify(
*core::get_value(frame_buffer));
// Convert from vision::ClassificationResult to
// processor::ClassificationResult
processor::ClassificationResult classification_result;
classification_result.ParseFromString(
core::get_value(vision_classification_result)
.SerializeAsString());
return classification_result;
})
.def("classify",
[](ImageClassifier& self, const ImageData& image_data,
const processor::BoundingBox& bounding_box)
-> tflite::support::StatusOr<ClassificationResult> {
-> processor::ClassificationResult {
// Convert from processor::BoundingBox to vision::BoundingBox as
// the latter is used in the C++ layer.
BoundingBox vision_bounding_box;
vision_bounding_box.ParseFromString(
bounding_box.SerializeAsString());
ASSIGN_OR_RETURN(std::unique_ptr<FrameBuffer> frame_buffer,
CreateFrameBufferFromImageData(image_data));
return self.Classify(*frame_buffer, vision_bounding_box);

auto frame_buffer = CreateFrameBufferFromImageData(image_data);
auto vision_classification_result = self.Classify(
*core::get_value(frame_buffer), vision_bounding_box);
// Convert from vision::ClassificationResult to
// processor::ClassificationResult
processor::ClassificationResult classification_result;
classification_result.ParseFromString(
core::get_value(vision_classification_result)
.SerializeAsString());
return classification_result;
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,9 @@ def test_create_from_options_succeeds_with_valid_model_path(self):
def test_create_from_options_fails_with_invalid_model_path(self):
# Invalid empty model path.
with self.assertRaisesRegex(
Exception,
r'INVALID_ARGUMENT: ExternalFile must specify at least one of '
r"'file_content', 'file_name' or 'file_descriptor_meta'. "
r"\[tflite::support::TfLiteSupportStatus='2']"):
RuntimeError,
r"ExternalFile must specify at least one of 'file_content', "
r"'file_name' or 'file_descriptor_meta'."):
base_options = _BaseOptions(file_name='')
options = _ImageClassifierOptions(base_options=base_options)
_ImageClassifier.create_from_options(options)
Expand Down Expand Up @@ -293,9 +292,8 @@ def test_combined_allowlist_and_denylist(self):
# Fails with combined allowlist and denylist
with self.assertRaisesRegex(
Exception,
r'INVALID_ARGUMENT: `class_name_whitelist` and `class_name_blacklist` '
r'are mutually exclusive options. '
r"\[tflite::support::TfLiteSupportStatus='2'\]"):
r"`class_name_whitelist` and `class_name_blacklist` are mutually "
r"exclusive options."):
base_options = _BaseOptions(file_name=self.model_path)
classification_options = classification_options_pb2.ClassificationOptions(
class_name_allowlist=['foo'], class_name_denylist=['bar'])
Expand Down

0 comments on commit 7884c9c

Please sign in to comment.