[go: nahoru, domu]

Skip to content

Commit

Permalink
Merge pull request #780 from khanhlvg:python-task-api-error-handling
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 440587605
  • Loading branch information
tflite-support-robot committed Apr 9, 2022
2 parents ce41685 + a745c84 commit 1ab917f
Show file tree
Hide file tree
Showing 15 changed files with 89 additions and 137 deletions.
21 changes: 4 additions & 17 deletions tensorflow_lite_support/python/task/audio/audio_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,9 @@ def create_from_file(cls, file_path: str) -> "AudioEmbedder":
`AudioEmbedder` object that's created from `options`.
Raises:
status.StatusNotOk if failed to create `AudioEmbedder` object from the
provided file such as invalid file.
RuntimeError if failed to create `AudioEmbedder` object from the provided
file such as invalid file.
"""
# TODO(b/220931229): Raise RuntimeError instead of status.StatusNotOk.
# Need to import the module to catch this error:
# `from pybind11_abseil import status`
# see https://github.com/pybind/pybind11_abseil#abslstatusor.
base_options = _BaseOptions(file_name=file_path)
options = AudioEmbedderOptions(base_options=base_options)
return cls.create_from_options(options)
Expand All @@ -80,13 +76,9 @@ def create_from_options(cls,
`AudioEmbedder` object that's created from `options`.
Raises:
status.StatusNotOk if failed to create `AudioEmbedder` object from
RuntimeError if failed to create `AudioEmbedder` object from
`AudioEmbedderOptions` such as missing the model.
"""
# TODO(b/220931229): Raise RuntimeError instead of status.StatusNotOk.
# Need to import the module to catch this error:
# `from pybind11_abseil import status`
# see https://github.com/pybind/pybind11_abseil#abslstatusor.
embedder = _CppAudioEmbedder.create_from_options(options.base_options,
options.embedding_options)
return cls(options, embedder)
Expand Down Expand Up @@ -122,13 +114,8 @@ def embed(self,
embedding result.
Raises:
status.StatusNotOk if failed to get the embedding vector.
RuntimeError if failed to get the embedding vector.
"""
# TODO(b/220931229): Raise RuntimeError instead of status.StatusNotOk.
# Need to import the module to catch this error:
# `from pybind11_abseil import status`
# see https://github.com/pybind/pybind11_abseil#abslstatusor.

return self._embedder.embed(
_CppAudioBuffer(audio.buffer, audio.buffer_size, audio.format))

Expand Down
3 changes: 1 addition & 2 deletions tensorflow_lite_support/python/task/audio/core/pybinds/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ pybind_extension(
],
module_name = "_pywrap_audio_buffer",
deps = [
"//tensorflow_lite_support/cc/port:statusor",
"//tensorflow_lite_support/cc/task/audio/core:audio_buffer",
"//tensorflow_lite_support/cc/task/audio/utils:audio_utils",
"//tensorflow_lite_support/python/task/core/pybinds:task_utils",
"@pybind11",
"@pybind11_abseil//pybind11_abseil:status_casters",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ limitations under the License.
==============================================================================*/
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h"
#include "tensorflow_lite_support/cc/task/audio/utils/audio_utils.h"
#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"

namespace tflite {
namespace task {
Expand All @@ -40,15 +39,16 @@ PYBIND11_MODULE(_pywrap_audio_buffer, m) {
.def_readonly("sample_rate", &AudioBuffer::AudioFormat::sample_rate);

py::class_<AudioBuffer>(m, "AudioBuffer", py::buffer_protocol())
.def(py::init([](py::buffer buffer, const int sample_count,
const AudioBuffer::AudioFormat& audio_format) {
py::buffer_info info = buffer.request();
.def(py::init([](
py::buffer buffer, const int sample_count,
const AudioBuffer::AudioFormat& audio_format)
-> std::unique_ptr<AudioBuffer> {
py::buffer_info info = buffer.request();

// TODO(b/220931229): Change this initializer to use AudioBuffer::Create
// and raise RuntimeError if initialization failed.
return absl::make_unique<AudioBuffer>(static_cast<float*>(info.ptr),
sample_count, audio_format);
}))
auto audio_buffer = AudioBuffer::Create(
static_cast<float*>(info.ptr), sample_count, audio_format);
return core::get_value(audio_buffer);
}))
.def_property_readonly("audio_format", &AudioBuffer::GetAudioFormat)
.def_property_readonly("buffer_size", &AudioBuffer::GetBufferSize)
.def_property_readonly("float_buffer", [](AudioBuffer& self) {
Expand All @@ -62,12 +62,13 @@ PYBIND11_MODULE(_pywrap_audio_buffer, m) {

m.def("LoadAudioBufferFromFile",
[](const std::string& wav_file, int buffer_size,
py::buffer buffer) -> tflite::support::StatusOr<AudioBuffer> {
py::buffer buffer) -> AudioBuffer {
py::buffer_info info = buffer.request();

return LoadAudioBufferFromFile(
auto audio_buffer = LoadAudioBufferFromFile(
wav_file, buffer_size,
static_cast<std::vector<float>*>(info.ptr));
return core::get_value(audio_buffer);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,8 @@ def create_from_wav_file(cls, file_name: str,
`TensorAudio` object.
Raises:
status.StatusNotOk if the audio file can't be decoded.
RuntimeError if the audio file can't be decoded.
"""
# TODO(b/220931229): Raise RuntimeError instead of status.StatusNotOk.
# Need to import the module to catch this error:
# `from pybind11_abseil import status`
# see https://github.com/pybind/pybind11_abseil#abslstatusor.
audio = _LoadAudioBufferFromFile(file_name, sample_count,
np.zeros([sample_count]))
tensor = TensorAudio(audio.audio_format, audio.buffer_size)
Expand Down
3 changes: 1 addition & 2 deletions tensorflow_lite_support/python/task/audio/pybinds/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ pybind_extension(
],
module_name = "_pywrap_audio_embedder",
deps = [
"//tensorflow_lite_support/cc/port:statusor",
"//tensorflow_lite_support/cc/task/audio:audio_embedder",
"//tensorflow_lite_support/cc/task/audio/core:audio_buffer",
"//tensorflow_lite_support/cc/task/processor/proto:embedding_cc_proto",
"//tensorflow_lite_support/cc/task/processor/proto:embedding_options_cc_proto",
"//tensorflow_lite_support/python/task/core/pybinds:task_utils",
"@pybind11",
"@pybind11_abseil//pybind11_abseil:status_casters",
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ limitations under the License.
==============================================================================*/

#include "pybind11/pybind11.h"
#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil
#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h"
#include "tensorflow_lite_support/cc/task/audio/audio_embedder.h"
#include "tensorflow_lite_support/cc/task/audio/core/audio_buffer.h"
#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"
Expand All @@ -34,7 +33,6 @@ using CppBaseOptions = ::tflite::task::core::BaseOptions;
PYBIND11_MODULE(_pywrap_audio_embedder, m) {
// python wrapper for C++ AudioEmbedder class which shouldn't be directly used
// by the users.
pybind11::google::ImportStatusModule();
pybind11_protobuf::ImportNativeProtoCasters();

py::class_<AudioEmbedder>(m, "AudioEmbedder")
Expand All @@ -48,10 +46,21 @@ PYBIND11_MODULE(_pywrap_audio_embedder, m) {

options.set_allocated_base_options(cpp_base_options.release());
options.add_embedding_options()->CopyFrom(embedding_options);
return AudioEmbedder::CreateFromOptions(options);
auto embedder = AudioEmbedder::CreateFromOptions(options);
return core::get_value(embedder);
})
.def_static("cosine_similarity", &AudioEmbedder::CosineSimilarity)
.def("embed", &AudioEmbedder::Embed)
.def_static("cosine_similarity",
[](const processor::FeatureVector& u,
const processor::FeatureVector& v) -> double {
auto similarity = AudioEmbedder::CosineSimilarity(u, v);
return core::get_value(similarity);
})
.def("embed",
[](AudioEmbedder& self,
const AudioBuffer& audio_buffer) -> processor::EmbeddingResult {
auto embedding_result = self.Embed(audio_buffer);
return core::get_value(embedding_result);
})
.def("get_embedding_dimension", &AudioEmbedder::GetEmbeddingDimension)
.def("get_number_of_output_layers",
&AudioEmbedder::GetNumberOfOutputLayers)
Expand Down
17 changes: 3 additions & 14 deletions tensorflow_lite_support/python/task/vision/image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,9 @@ def create_from_file(cls, file_path: str) -> "ImageClassifier":
Returns:
`ImageClassifier` object that's created from the model file.
Raises:
status.StatusNotOk if failed to create `ImageClassifier` object from the
RuntimeError if failed to create `ImageClassifier` object from the
provided file such as invalid file.
"""
# TODO(b/220931229): Raise RuntimeError instead of status.StatusNotOk.
# Need to import the module to catch this error:
# `from pybind11_abseil import status`
# see https://github.com/pybind/pybind11_abseil#abslstatusor.
base_options = _BaseOptions(file_name=file_path)
options = ImageClassifierOptions(base_options=base_options)
return cls.create_from_options(options)
Expand All @@ -76,13 +72,9 @@ def create_from_options(cls,
Returns:
`ImageClassifier` object that's created from `options`.
Raises:
status.StatusNotOk if failed to create `ImageClassifier` object from
RuntimeError if failed to create `ImageClassifier` object from
`ImageClassifierOptions` such as missing the model.
"""
# TODO(b/220931229): Raise RuntimeError instead of status.StatusNotOk.
# Need to import the module to catch this error:
# `from pybind11_abseil import status`
# see https://github.com/pybind/pybind11_abseil#abslstatusor.
classifier = _CppImageClassifier.create_from_options(
options.base_options, options.classification_options)
return cls(options, classifier)
Expand All @@ -103,10 +95,7 @@ def classify(
Returns:
classification result.
Raises:
status.StatusNotOk if failed to get the feature vector. Need to import the
module to catch this error: `from pybind11_abseil
import status`, see
https://github.com/pybind/pybind11_abseil#abslstatusor.
RuntimeError if failed to get the classification result.
"""
image_data = image_utils.ImageData(image.buffer)
if bounding_box is None:
Expand Down
21 changes: 5 additions & 16 deletions tensorflow_lite_support/python/task/vision/object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,9 @@ def create_from_file(cls, file_path: str) -> "ObjectDetector":
Returns:
`ObjectDetector` object that's created from the model file.
Raises:
status.StatusNotOk if failed to create `ObjectDetector` object from the
provided file such as invalid file.
RuntimeError if failed to create `ObjectDetector` object from the provided
file such as invalid file.
"""
# TODO(b/220931229): Raise RuntimeError instead of status.StatusNotOk.
# Need to import the module to catch this error:
# `from pybind11_abseil import status`
# see https://github.com/pybind/pybind11_abseil#abslstatusor.
base_options = _BaseOptions(file_name=file_path)
options = ObjectDetectorOptions(base_options=base_options)
return cls.create_from_options(options)
Expand All @@ -76,13 +72,9 @@ def create_from_options(cls,
Returns:
`ObjectDetector` object that's created from `options`.
Raises:
status.StatusNotOk if failed to create `ObjectDetector` object from
`ObjectDetectorOptions` such as missing the model.
RuntimeError if failed to create `ObjectDetector` object from
`ObjectDetectorOptions` such as missing the model.
"""
# TODO(b/220931229): Raise RuntimeError instead of status.StatusNotOk.
# Need to import the module to catch this error:
# `from pybind11_abseil import status`
# see https://github.com/pybind/pybind11_abseil#abslstatusor.
detector = _CppObjectDetector.create_from_options(options.base_options,
options.detection_options)
return cls(options, detector)
Expand All @@ -97,10 +89,7 @@ def detect(self,
Returns:
detection result.
Raises:
status.StatusNotOk if failed to get the feature vector. Need to import the
module to catch this error: `from pybind11_abseil
import status`, see
https://github.com/pybind/pybind11_abseil#abslstatusor.
RuntimeError if failed to get the detection result.
"""
image_data = image_utils.ImageData(image.buffer)

Expand Down
4 changes: 0 additions & 4 deletions tensorflow_lite_support/python/task/vision/pybinds/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,13 @@ pybind_extension(
],
module_name = "_pywrap_image_classifier",
deps = [
"//tensorflow_lite_support/cc/port:statusor",
"//tensorflow_lite_support/cc/task/processor/proto:bounding_box_cc_proto",
"//tensorflow_lite_support/cc/task/processor/proto:classification_options_cc_proto",
"//tensorflow_lite_support/cc/task/processor/proto:classifications_cc_proto",
"//tensorflow_lite_support/cc/task/vision:image_classifier",
"//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils",
"//tensorflow_lite_support/python/task/core/pybinds:task_utils",
"@pybind11",
"@pybind11_abseil//pybind11_abseil:status_casters",
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
],
)
Expand All @@ -52,14 +50,12 @@ pybind_extension(
],
module_name = "_pywrap_object_detector",
deps = [
"//tensorflow_lite_support/cc/port:statusor",
"//tensorflow_lite_support/cc/task/processor/proto:detection_options_cc_proto",
"//tensorflow_lite_support/cc/task/processor/proto:detections_cc_proto",
"//tensorflow_lite_support/cc/task/vision:object_detector",
"//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils",
"//tensorflow_lite_support/python/task/core/pybinds:task_utils",
"@pybind11",
"@pybind11_abseil//pybind11_abseil:status_casters",
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ limitations under the License.
==============================================================================*/

#include "pybind11/pybind11.h"
#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil
#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/processor/proto/bounding_box.pb.h"
#include "tensorflow_lite_support/cc/task/processor/proto/classification_options.pb.h"
#include "tensorflow_lite_support/cc/task/processor/proto/classifications.pb.h"
#include "tensorflow_lite_support/cc/task/processor/proto/classification_options.pb.h"
#include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"
Expand All @@ -37,7 +35,6 @@ using CppBaseOptions = ::tflite::task::core::BaseOptions;
PYBIND11_MODULE(_pywrap_image_classifier, m) {
// python wrapper for C++ ImageClassifier class which shouldn't be directly
// used by the users.
pybind11::google::ImportStatusModule();
pybind11_protobuf::ImportNativeProtoCasters();

py::class_<ImageClassifier>(m, "ImageClassifier")
Expand Down Expand Up @@ -66,43 +63,42 @@ PYBIND11_MODULE(_pywrap_image_classifier, m) {
options.mutable_class_name_blacklist()->CopyFrom(
classification_options.class_name_denylist());

return ImageClassifier::CreateFromOptions(options);
auto classifier = ImageClassifier::CreateFromOptions(options);
return core::get_value(classifier);
})
.def("classify",
[](ImageClassifier& self, const ImageData& image_data)
-> tflite::support::StatusOr<processor::ClassificationResult> {
ASSIGN_OR_RETURN(std::unique_ptr<FrameBuffer> frame_buffer,
CreateFrameBufferFromImageData(image_data));
ASSIGN_OR_RETURN(ClassificationResult vision_classification_result,
self.Classify(*frame_buffer));

-> processor::ClassificationResult {
auto frame_buffer = CreateFrameBufferFromImageData(image_data);
auto vision_classification_result = self.Classify(
*core::get_value(frame_buffer));
// Convert from vision::ClassificationResult to
// processor::ClassificationResult as required by the Python layer.
processor::ClassificationResult classification_result;
classification_result.ParseFromString(
vision_classification_result.SerializeAsString());
classification_result.ParseFromString(
core::get_value(vision_classification_result)
.SerializeAsString());
return classification_result;
})
.def("classify",
[](ImageClassifier& self, const ImageData& image_data,
const processor::BoundingBox& bounding_box)
-> tflite::support::StatusOr<processor::ClassificationResult> {
-> processor::ClassificationResult {
// Convert from processor::BoundingBox to vision::BoundingBox as
// the latter is used in the C++ layer.
BoundingBox vision_bounding_box;
vision_bounding_box.ParseFromString(
bounding_box.SerializeAsString());
ASSIGN_OR_RETURN(std::unique_ptr<FrameBuffer> frame_buffer,
CreateFrameBufferFromImageData(image_data));
ASSIGN_OR_RETURN(
ClassificationResult vision_classification_result,
self.Classify(*frame_buffer, vision_bounding_box));

auto frame_buffer = CreateFrameBufferFromImageData(image_data);
auto vision_classification_result = self.Classify(
*core::get_value(frame_buffer), vision_bounding_box);
// Convert from vision::ClassificationResult to
// processor::ClassificationResult as required by the Python layer.
processor::ClassificationResult classification_result;
classification_result.ParseFromString(
vision_classification_result.SerializeAsString());
classification_result.ParseFromString(
core::get_value(vision_classification_result)
.SerializeAsString());
return classification_result;
});
}
Expand Down
Loading

0 comments on commit 1ab917f

Please sign in to comment.