[go: nahoru, domu]

Skip to content

Commit

Permalink
Merge pull request #846 from khanhlvg:ios-audio-classifier-tests
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 459284920
  • Loading branch information
tflite-support-robot committed Jul 6, 2022
2 parents 5ad9c54 + 5384a48 commit 9023f40
Show file tree
Hide file tree
Showing 9 changed files with 407 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,5 @@ objc_library(
"//tensorflow_lite_support/ios/task/audio/core:TFLAudioFormat",
"//tensorflow_lite_support/ios/task/audio/core:TFLFloatBuffer",
"//tensorflow_lite_support/ios/task/audio/core:TFLRingBuffer",
"//third_party/apple_frameworks:AVFoundation",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ objc_library(
deps = [
"//tensorflow_lite_support/ios:TFLCommon",
"//tensorflow_lite_support/ios:TFLCommonUtils",
"//tensorflow_lite_support/ios/task/audio/core:TFLFloatBuffer",
"//tensorflow_lite_support/ios/task/audio/core:TFLRingBuffer",
"//tensorflow_lite_support/ios/task/audio/core/audio_record:TFLAudioRecord",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

#import <Foundation/Foundation.h>
#import "tensorflow_lite_support/ios/task/audio/core/audio_record/sources/TFLAudioRecord.h"
#import "tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h"
#import "tensorflow_lite_support/ios/task/audio/core/sources/TFLFloatBuffer.h"

NS_ASSUME_NONNULL_BEGIN

Expand Down Expand Up @@ -70,7 +70,7 @@ NS_SWIFT_NAME(AudioTensor)
* @return A boolean indicating if the load operation succeeded.
*/
- (BOOL)loadAudioRecord:(TFLAudioRecord *)audioRecord
withError:(NSError **)error NS_SWIFT_NAME(loadAudioRecord(audioRecord:));
withError:(NSError **)error NS_SWIFT_NAME(load(audioRecord:));

/**
* This function loads the internal buffer of `TFLAudioTensor` with the provided buffer.
Expand All @@ -90,9 +90,9 @@ NS_SWIFT_NAME(AudioTensor)
* @return A boolean indicating if the load operation succeeded.
*/
- (BOOL)loadBuffer:(TFLFloatBuffer *)buffer
offset:(NSInteger)offset
size:(NSInteger)size
error:(NSError **)error;
offset:(NSUInteger)offset
size:(NSUInteger)size
error:(NSError **)error NS_SWIFT_NAME(load(buffer:offset:size:));

@end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#import "tensorflow_lite_support/ios/task/audio/core/audio_tensor/sources/TFLAudioTensor.h"
#import "tensorflow_lite_support/ios/sources/TFLCommon.h"
#import "tensorflow_lite_support/ios/sources/TFLCommonUtils.h"
#import "tensorflow_lite_support/ios/task/audio/core/sources/TFLRingBuffer.h"

@implementation TFLAudioTensor {
TFLRingBuffer *_ringBuffer;
Expand All @@ -32,8 +33,8 @@ - (instancetype)initWithAudioFormat:(TFLAudioFormat *)format sampleCount:(NSUInt
}

- (BOOL)loadBuffer:(TFLFloatBuffer *)buffer
offset:(NSInteger)offset
size:(NSInteger)size
offset:(NSUInteger)offset
size:(NSUInteger)size
error:(NSError **)error {
return [_ringBuffer loadFloatData:buffer.data
dataSize:buffer.size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ NS_SWIFT_NAME(AudioClassifier)
* in initializing the audio classifier.
*/
+ (nullable instancetype)audioClassifierWithOptions:(TFLAudioClassifierOptions *)options
error:(NSError **)error;
error:(NSError **)error
NS_SWIFT_NAME(classifier(options:));

+ (instancetype)new NS_UNAVAILABLE;

Expand All @@ -89,7 +90,7 @@ NS_SWIFT_NAME(AudioClassifier)
* @return A `TFLAudioTensor` with the same buffer size as the model input tensor and audio format
* required by the model, if creation is successful otherwise nil.
*/
- (nullable TFLAudioTensor *)createInputAudioTensor;
- (TFLAudioTensor *)createInputAudioTensor;

/**
* Creates a `TFLAudioRecord` instance to start recording audio input from the microphone. The
Expand Down
28 changes: 28 additions & 0 deletions tensorflow_lite_support/ios/test/task/audio/audio_classifier/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
load("@org_tensorflow//tensorflow/lite/ios:ios.bzl", "TFL_DEFAULT_TAGS", "TFL_DISABLED_SANITIZER_TAGS", "TFL_MINIMUM_OS_VERSION")
load("@build_bazel_rules_apple//apple:ios.bzl", "ios_unit_test")
load("@org_tensorflow//tensorflow/lite:special_rules.bzl", "tflite_ios_lab_runner")
load("@build_bazel_rules_swift//swift:swift.bzl", "swift_library")

package(
default_visibility = ["//visibility:private"],
Expand All @@ -17,6 +18,7 @@ objc_library(
],
tags = TFL_DEFAULT_TAGS,
deps = [
"//tensorflow_lite_support/ios:TFLCommon",
"//tensorflow_lite_support/ios/task/audio:TFLAudioClassifier",
"//tensorflow_lite_support/ios/test/task/audio/core/audio_record/utils:AVAudioPCMBufferUtils",
],
Expand All @@ -31,3 +33,29 @@ ios_unit_test(
":TFLAudioClassifierObjcTestLibrary",
],
)

# Swift-language test library for the audio classifier, mirroring the
# Objective-C test target above. `testonly` restricts use to test targets;
# `data` bundles the audio clips and models the tests read at runtime.
swift_library(
name = "TFLAudioClassifierSwiftTestLibrary",
testonly = 1,
srcs = ["TFLAudioClassifierTests.swift"],
data = [
"//tensorflow_lite_support/cc/test/testdata/task/audio:test_audio_clips",
"//tensorflow_lite_support/cc/test/testdata/task/audio:test_models",
],
tags = TFL_DEFAULT_TAGS,
deps = [
"//tensorflow_lite_support/ios:TFLCommon",
"//tensorflow_lite_support/ios/task/audio:TFLAudioClassifier",
"//tensorflow_lite_support/ios/test/task/audio/core/audio_record/utils:AVAudioPCMBufferUtils",
],
)

# Runs the Swift test library as an iOS unit test on the lab runner's
# "IOS_LATEST" configuration, with sanitizer-incompatible tags disabled.
ios_unit_test(
name = "TFLAudioClassifierSwiftTest",
minimum_os_version = TFL_MINIMUM_OS_VERSION,
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = TFL_DEFAULT_TAGS + TFL_DISABLED_SANITIZER_TAGS,
deps = [
":TFLAudioClassifierSwiftTestLibrary",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
==============================================================================*/
#import <XCTest/XCTest.h>

#import "tensorflow_lite_support/ios/sources/TFLCommon.h"
#import "tensorflow_lite_support/ios/task/audio/sources/TFLAudioClassifier.h"
#import "tensorflow_lite_support/ios/test/task/audio/core/audio_record/utils/sources/AVAudioPCMBuffer+Utils.h"

Expand Down Expand Up @@ -48,7 +49,7 @@ @interface TFLAudioClassifierTests : XCTestCase
@property(nonatomic) AVAudioFormat *audioEngineFormat;
@end

// This category of TFLAudioRecord is private to the current test file. This is needed in order to
// This category of TFLAudioRecord is private to the test files. This is needed in order to
// expose the method to load the audio record buffer without calling: -[TFLAudioRecord
// startRecordingWithError:]. This is needed to avoid exposing this method which isn't useful to the
// consumers of the framework.
Expand Down Expand Up @@ -123,6 +124,14 @@ - (TFLAudioClassifier *)createAudioClassifierWithModelPath:(NSString *)modelPath
return audioClassifier;
}

// Convenience helper: constructs an audio classifier from the given options and
// fails the current test immediately if construction does not succeed.
- (TFLAudioClassifier *)createAudioClassifierWithOptions:(TFLAudioClassifierOptions *)options {
  TFLAudioClassifier *classifier = [TFLAudioClassifier audioClassifierWithOptions:options
                                                                            error:nil];
  XCTAssertNotNil(classifier);
  return classifier;
}

- (TFLAudioTensor *)createAudioTensorWithAudioClassifier:(TFLAudioClassifier *)audioClassifier {
// Create the audio tensor using audio classifier.
TFLAudioTensor *audioTensor = [audioClassifier createInputAudioTensor];
Expand All @@ -146,15 +155,15 @@ - (void)loadAudioTensor:(TFLAudioTensor *)audioTensor fromWavFileWithName:(NSStr
}

- (TFLClassificationResult *)classifyWithAudioClassifier:(TFLAudioClassifier *)audioClassifier
audioTensor:(TFLAudioTensor *)audioTensor {
audioTensor:(TFLAudioTensor *)audioTensor
expectedCategoryCount:(const NSInteger)expectedCategoryCount {
TFLClassificationResult *classificationResult =
[audioClassifier classifyWithAudioTensor:audioTensor error:nil];

const NSInteger expectedClassificationsCount = 1;
VerifyClassificationResult(classificationResult, expectedClassificationsCount);

const NSInteger expectedHeadIndex = 0;
const NSInteger expectedCategoryCount = 521;
VerifyClassifications(classificationResult.classifications[0], expectedHeadIndex,
expectedCategoryCount);

Expand All @@ -175,42 +184,36 @@ - (void)validateClassificationResultForInferenceWithFloatBuffer:
@"Inside, small room", // expectedLabel
nil // expectedDisplaName
);
VerifyCategory(categories[2],
3, // expectedIndex
0.003906, // expectedScore
@"Narration, monologue", // expectedLabel
nil // expectedDisplaName
);
}

// Validates the categories produced by inference driven through `TFLAudioRecord`.
//
// The expected scores differ slightly from the ones checked in
// -[TFLAudioClassifierTests validateClassificationResultForInferenceWithFloatBuffer:].
// Inference with an audio record involves more native format conversions (mocking the
// conversions the audio record performs) than inference with a float buffer. Each
// conversion by `AVAudioConverter` employs its own sample-picking strategy for the
// requested format, so the samples ultimately passed to the model differ slightly
// between the two paths.
//
// NOTE(review): the original span was corrupted (duplicated argument lines for
// categories[1] and an unterminated VerifyCategory call); reconstructed here with the
// updated expected score (0.019531) — confirm against the committed file.
- (void)validateCategoriesForInferenceWithAudioRecord:(NSArray<TFLCategory *> *)categories {
  VerifyCategory(categories[0],
                 0,          // expectedIndex
                 0.957031,   // expectedScore
                 @"Speech",  // expectedLabel
                 nil         // expectedDisplayName
  );
  VerifyCategory(categories[1],
                 500,                    // expectedIndex
                 0.019531,               // expectedScore
                 @"Inside, small room",  // expectedLabel
                 nil                     // expectedDisplayName
  );
}

// Checks the top category returned when "Speech" is deny-listed: the best remaining
// label should be "Inside, small room".
- (void)validateCategoriesForInferenceWithLabelDenyList:(NSArray<TFLCategory *> *)categories {
  TFLCategory *topCategory = categories[0];
  VerifyCategory(topCategory,
                 500,                    // expectedIndex
                 0.019531,               // expectedScore
                 @"Inside, small room",  // expectedLabel
                 nil                     // expectedDisplayName
  );
}

Expand All @@ -228,8 +231,12 @@ - (void)testInferenceWithFloatBufferSucceeds {
TFLAudioTensor *audioTensor = [self createAudioTensorWithAudioClassifier:audioClassifier];
[self loadAudioTensor:audioTensor fromWavFileWithName:@"speech"];

TFLClassificationResult *classificationResult = [self classifyWithAudioClassifier:audioClassifier
audioTensor:audioTensor];
const NSInteger expectedCategoryCount = 521;
TFLClassificationResult *classificationResult =
[self classifyWithAudioClassifier:audioClassifier
audioTensor:audioTensor
expectedCategoryCount:expectedCategoryCount];

[self validateClassificationResultForInferenceWithFloatBuffer:classificationResult
.classifications[0]
.categories];
Expand All @@ -246,13 +253,126 @@ - (void)testInferenceWithAudioRecordSucceeds {
// Load the audioRecord buffer into the audio tensor.
[audioTensor loadAudioRecord:audioRecord withError:nil];

TFLClassificationResult *classificationResult = [self classifyWithAudioClassifier:audioClassifier
audioTensor:audioTensor];
const NSInteger expectedCategoryCount = 521;
TFLClassificationResult *classificationResult =
[self classifyWithAudioClassifier:audioClassifier
audioTensor:audioTensor
expectedCategoryCount:expectedCategoryCount];

[self validateCategoriesForInferenceWithAudioRecord:classificationResult.classifications[0]
.categories];
}

// Verifies that setting `maxResults` caps the number of returned categories and that
// the top results still match the float-buffer expectations.
- (void)testInferenceWithMaxResultsSucceeds {
  const NSInteger maxResults = 3;
  TFLAudioClassifierOptions *options =
      [[TFLAudioClassifierOptions alloc] initWithModelPath:self.modelPath];
  options.classificationOptions.maxResults = maxResults;

  TFLAudioClassifier *classifier = [self createAudioClassifierWithOptions:options];
  TFLAudioTensor *audioTensor = [self createAudioTensorWithAudioClassifier:classifier];
  [self loadAudioTensor:audioTensor fromWavFileWithName:@"speech"];

  TFLClassificationResult *result = [self classifyWithAudioClassifier:classifier
                                                          audioTensor:audioTensor
                                                expectedCategoryCount:maxResults];

  NSArray<TFLCategory *> *categories = result.classifications[0].categories;
  [self validateClassificationResultForInferenceWithFloatBuffer:categories];
}

// Configuring both a label allow list and a deny list is invalid: classifier creation
// must return nil and surface an invalid-argument NSError.
- (void)testInferenceWithClassNameAllowListAndDenyListFails {
  TFLAudioClassifierOptions *options =
      [[TFLAudioClassifierOptions alloc] initWithModelPath:self.modelPath];
  options.classificationOptions.labelAllowList = @[ @"Speech" ];
  options.classificationOptions.labelDenyList = @[ @"Inside, small room" ];

  NSError *creationError = nil;
  TFLAudioClassifier *classifier = [TFLAudioClassifier audioClassifierWithOptions:options
                                                                            error:&creationError];

  XCTAssertNil(classifier);
  VerifyError(creationError,
              expectedTaskErrorDomain,                 // expectedErrorDomain
              TFLSupportErrorCodeInvalidArgumentError, // expectedErrorCode
              @"INVALID_ARGUMENT: `class_name_allowlist` and `class_name_denylist` are mutually "
              @"exclusive options."                    // expectedErrorMessage
  );
}

// Verifies that a label allow list restricts the results to exactly the allowed labels.
- (void)testInferenceWithLabelAllowListSucceeds {
  TFLAudioClassifierOptions *options =
      [[TFLAudioClassifierOptions alloc] initWithModelPath:self.modelPath];
  options.classificationOptions.labelAllowList = @[ @"Speech", @"Inside, small room" ];

  // Use the shared creation helper so a nil classifier fails the test fast. The previous
  // version captured an NSError it never checked and never asserted the classifier was
  // non-nil before using it, which would crash later with a less useful failure.
  TFLAudioClassifier *audioClassifier = [self createAudioClassifierWithOptions:options];
  TFLAudioTensor *audioTensor = [self createAudioTensorWithAudioClassifier:audioClassifier];
  [self loadAudioTensor:audioTensor fromWavFileWithName:@"speech"];

  // With an allow list, the category count equals the number of allowed labels.
  TFLClassificationResult *classificationResult =
      [self classifyWithAudioClassifier:audioClassifier
                            audioTensor:audioTensor
                  expectedCategoryCount:options.classificationOptions.labelAllowList.count];

  [self validateClassificationResultForInferenceWithFloatBuffer:classificationResult
                                                                    .classifications[0]
                                                                    .categories];
}

// Verifies that deny-listing "Speech" removes exactly that one category from the model's
// 521 outputs and promotes "Inside, small room" to the top result.
- (void)testInferenceWithLabelDenyListSucceeds {
  TFLAudioClassifierOptions *options =
      [[TFLAudioClassifierOptions alloc] initWithModelPath:self.modelPath];
  options.classificationOptions.labelDenyList = @[ @"Speech" ];

  // Use the shared creation helper so a nil classifier fails the test fast. The previous
  // version captured an NSError it never checked and never asserted the classifier was
  // non-nil before using it.
  TFLAudioClassifier *audioClassifier = [self createAudioClassifierWithOptions:options];
  TFLAudioTensor *audioTensor = [self createAudioTensorWithAudioClassifier:audioClassifier];
  [self loadAudioTensor:audioTensor fromWavFileWithName:@"speech"];

  // 521 total categories minus the single deny-listed label.
  const NSInteger expectedCategoryCount = 520;
  TFLClassificationResult *classificationResult =
      [self classifyWithAudioClassifier:audioClassifier
                            audioTensor:audioTensor
                  expectedCategoryCount:expectedCategoryCount];

  [self validateCategoriesForInferenceWithLabelDenyList:classificationResult.classifications[0]
                                                            .categories];
}

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnonnull"
// Passing nil options is a programmer error: creation must return nil and surface an
// invalid-argument NSError. (The enclosing pragma suppresses the -Wnonnull warning for
// the deliberate nil argument.)
- (void)testCreateAudioClassifierWithNilOptionsFails {
  NSError *creationError = nil;
  TFLAudioClassifier *classifier = [TFLAudioClassifier audioClassifierWithOptions:nil
                                                                            error:&creationError];

  XCTAssertNil(classifier);
  VerifyError(creationError,
              expectedTaskErrorDomain,                            // expectedErrorDomain
              TFLSupportErrorCodeInvalidArgumentError,            // expectedErrorCode
              @"TFLAudioClassifierOptions argument cannot be nil." // expectedErrorMessage
  );
}

// Classifying a nil audio tensor must produce no result and an invalid-argument NSError.
- (void)testInferenceWithNilAudioTensorFails {
  TFLAudioClassifier *classifier = [self createAudioClassifierWithModelPath:self.modelPath];

  NSError *classificationError = nil;
  TFLClassificationResult *result = [classifier classifyWithAudioTensor:nil
                                                                  error:&classificationError];

  XCTAssertNil(result);
  VerifyError(classificationError,
              expectedTaskErrorDomain,                 // expectedErrorDomain
              TFLSupportErrorCodeInvalidArgumentError, // expectedErrorCode
              @"audioTensor argument cannot be nil."   // expectedErrorMessage
  );
}
#pragma clang diagnostic pop

@end

NS_ASSUME_NONNULL_END
Loading

0 comments on commit 9023f40

Please sign in to comment.