-
Notifications
You must be signed in to change notification settings - Fork 124
/
metadata_info.py
817 lines (685 loc) · 32 KB
/
metadata_info.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper classes for common model metadata information."""
import collections
import csv
import os
from typing import List, Optional, Type, Union
from tensorflow_lite_support.metadata import metadata_schema_py_generated as _metadata_fb
from tensorflow_lite_support.metadata import schema_py_generated as _schema_fb
from tensorflow_lite_support.metadata.python.metadata_writers import writer_utils
# Min and max values for UINT8 tensors. Used in this module as the
# single-element min/max stats of tensors whose type is UINT8.
_MIN_UINT8 = 0
_MAX_UINT8 = 255
# Default description for vocabulary files. Shared by the tokenizer metadata
# containers in this module when they describe their vocabulary file.
_VOCAB_FILE_DESCRIPTION = ("Vocabulary file to convert natural language "
"words to embedding vectors.")
class GeneralMd:
  """Holds the general, model-level metadata fields.

  Attributes:
    name: name of the model.
    version: version of the model.
    description: description of what the model does.
    author: author of the model.
    licenses: licenses of the model.
  """

  def __init__(self,
               name: Optional[str] = None,
               version: Optional[str] = None,
               description: Optional[str] = None,
               author: Optional[str] = None,
               licenses: Optional[str] = None):
    """Stores the given model information; every field may be omitted."""
    self.name, self.version, self.description = name, version, description
    self.author, self.licenses = author, licenses

  def create_metadata(self) -> _metadata_fb.ModelMetadataT:
    """Builds the model metadata object from the stored fields.

    Returns:
      A Flatbuffers Python object of the model metadata.
    """
    metadata = _metadata_fb.ModelMetadataT()
    # (target attribute on the Flatbuffers object, stored value). Note that
    # the plural `licenses` is written to the singular `license` attribute.
    field_values = (
        ("name", self.name),
        ("version", self.version),
        ("description", self.description),
        ("author", self.author),
        ("license", self.licenses),
    )
    for attribute, value in field_values:
      setattr(metadata, attribute, value)
    return metadata
class AssociatedFileMd:
  """Describes a single file that is associated with a model or tensor.

  Attributes:
    file_path: path to the associated file.
    description: description of the associated file.
    file_type: file type of the associated file [1].
    locale: locale of the associated file [2].
  [1]:
    https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L77
  [2]:
    https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L154
  """

  def __init__(
      self,
      file_path: str,
      description: Optional[str] = None,
      file_type: Optional[_metadata_fb.AssociatedFileType] = (
          _metadata_fb.AssociatedFileType.UNKNOWN),
      locale: Optional[str] = None):
    """Records the file information; the file itself is not opened here."""
    self.file_path = file_path
    self.description = description
    self.file_type = file_type
    self.locale = locale

  def create_metadata(self) -> _metadata_fb.AssociatedFileT:
    """Builds the associated file metadata object.

    Returns:
      A Flatbuffers Python object of the associated file metadata.
    """
    # Only the base name of the path is recorded as the file name.
    base_name = os.path.basename(self.file_path)

    associated_file = _metadata_fb.AssociatedFileT()
    associated_file.name = base_name
    associated_file.description = self.description
    associated_file.type = self.file_type
    associated_file.locale = self.locale
    return associated_file
class LabelFileMd(AssociatedFileMd):
  """An associated file entry specialised for label files.

  The description and file type are fixed by the class constants below; only
  the path and locale are supplied by the caller.
  """

  _LABEL_FILE_DESCRIPTION = ("Labels for categories that the model can "
                             "recognize.")
  _FILE_TYPE = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS

  def __init__(self, file_path: str, locale: Optional[str] = None):
    """Creates a LabelFileMd object.

    Args:
      file_path: file_path of the label file.
      locale: locale of the label file [1].
      [1]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L154
    """
    super().__init__(
        file_path=file_path,
        description=self._LABEL_FILE_DESCRIPTION,
        file_type=self._FILE_TYPE,
        locale=locale)
class RegexTokenizerMd:
  """Holds the information needed to describe a Regex tokenizer [1].

  [1]:
    https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L459
  """

  def __init__(self, delim_regex_pattern: str, vocab_file_path: str):
    """Initializes a RegexTokenizerMd object.

    Args:
      delim_regex_pattern: the regular expression to segment strings and create
        tokens.
      vocab_file_path: path to the vocabulary file.
    """
    self._delim_regex_pattern = delim_regex_pattern
    self._vocab_file_path = vocab_file_path

  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
    """Creates the Regex tokenizer metadata based on the information.

    Returns:
      A Flatbuffers Python object of the Regex tokenizer metadata.
    """
    # Describe the vocabulary file; the path is stored as given (it is not
    # reduced to a base name here).
    vocab_file = _metadata_fb.AssociatedFileT()
    vocab_file.name = self._vocab_file_path
    vocab_file.description = _VOCAB_FILE_DESCRIPTION
    vocab_file.type = _metadata_fb.AssociatedFileType.VOCABULARY

    # Tokenizer options: the delimiter pattern plus the single vocab file.
    regex_options = _metadata_fb.RegexTokenizerOptionsT()
    regex_options.delimRegexPattern = self._delim_regex_pattern
    regex_options.vocabFile = [vocab_file]

    # Wrap the options in a process unit tagged with the matching type.
    process_unit = _metadata_fb.ProcessUnitT()
    process_unit.optionsType = (
        _metadata_fb.ProcessUnitOptions.RegexTokenizerOptions)
    process_unit.options = regex_options
    return process_unit
class BertTokenizerMd:
  """Holds the information needed to describe a Bert tokenizer [1].

  [1]:
    https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436
  """

  def __init__(self, vocab_file_path: str):
    """Initializes a BertTokenizerMd object.

    Args:
      vocab_file_path: path to the vocabulary file.
    """
    self._vocab_file_path = vocab_file_path

  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
    """Creates the Bert tokenizer metadata based on the information.

    Returns:
      A Flatbuffers Python object of the Bert tokenizer metadata.
    """
    # Describe the vocabulary file; the path is stored as given.
    vocab_file = _metadata_fb.AssociatedFileT()
    vocab_file.name = self._vocab_file_path
    vocab_file.description = _VOCAB_FILE_DESCRIPTION
    vocab_file.type = _metadata_fb.AssociatedFileType.VOCABULARY

    bert_options = _metadata_fb.BertTokenizerOptionsT()
    bert_options.vocabFile = [vocab_file]

    # Wrap the options in a process unit tagged with the matching type.
    process_unit = _metadata_fb.ProcessUnitT()
    process_unit.optionsType = _metadata_fb.ProcessUnitOptions.BertTokenizerOptions
    process_unit.options = bert_options
    return process_unit
class SentencePieceTokenizerMd:
  """Holds the information needed to describe a sentence piece tokenizer [1].

  [1]:
    https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473
  """

  _SP_MODEL_DESCRIPTION = "The sentence piece model file."
  _SP_VOCAB_FILE_DESCRIPTION = _VOCAB_FILE_DESCRIPTION + (
      " This file is optional during tokenization, while the sentence piece "
      "model is mandatory.")

  def __init__(self,
               sentence_piece_model_path: str,
               vocab_file_path: Optional[str] = None):
    """Initializes a SentencePieceTokenizerMd object.

    Args:
      sentence_piece_model_path: path to the sentence piece model file.
      vocab_file_path: path to the vocabulary file.
    """
    self._sentence_piece_model_path = sentence_piece_model_path
    self._vocab_file_path = vocab_file_path

  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
    """Creates the sentence piece tokenizer metadata based on the information.

    Returns:
      A Flatbuffers Python object of the sentence piece tokenizer metadata.
    """
    process_unit = _metadata_fb.ProcessUnitT()
    process_unit.optionsType = (
        _metadata_fb.ProcessUnitOptions.SentencePieceTokenizerOptions)
    sp_options = _metadata_fb.SentencePieceTokenizerOptionsT()
    process_unit.options = sp_options

    # The sentence piece model file is always recorded. No file type is
    # assigned to it here.
    model_file = _metadata_fb.AssociatedFileT()
    model_file.name = self._sentence_piece_model_path
    model_file.description = self._SP_MODEL_DESCRIPTION
    sp_options.sentencePieceModel = [model_file]

    # The vocabulary file is only recorded when a path was supplied.
    if not self._vocab_file_path:
      return process_unit

    vocab_file = _metadata_fb.AssociatedFileT()
    vocab_file.name = self._vocab_file_path
    vocab_file.description = self._SP_VOCAB_FILE_DESCRIPTION
    vocab_file.type = _metadata_fb.AssociatedFileType.VOCABULARY
    sp_options.vocabFile = [vocab_file]
    return process_unit
class ScoreCalibrationMd:
  """Holds the information needed to describe score calibration [1].

  [1]:
    https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434
  """

  _SCORE_CALIBRATION_FILE_DESCRIPTION = (
      "Contains sigmoid-based score calibration parameters. The main purposes "
      "of score calibration is to make scores across classes comparable, so "
      "that a common threshold can be used for all output classes.")
  _FILE_TYPE = _metadata_fb.AssociatedFileType.TENSOR_AXIS_SCORE_CALIBRATION

  def __init__(self,
               score_transformation_type: _metadata_fb.ScoreTransformationType,
               default_score: float, file_path: str):
    """Creates a ScoreCalibrationMd object.

    The calibration file is opened and checked line by line during
    construction, so a missing or malformed file fails early.

    Args:
      score_transformation_type: type of the function used for transforming the
        uncalibrated score before applying score calibration.
      default_score: the default calibrated score to apply if the uncalibrated
        score is below min_score or if no parameters were specified for a given
        index.
      file_path: file_path of the score calibration file [1].
      [1]:
        https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L122

    Raises:
      ValueError: if the score_calibration file is malformed.
    """
    self._score_transformation_type = score_transformation_type
    self._default_score = default_score
    self._file_path = file_path

    # Sanity check the score calibration file.
    with open(self._file_path) as calibration_file:
      for row in csv.reader(calibration_file, delimiter=","):
        if not row:
          # Empty lines are allowed.
          continue
        if len(row) not in (3, 4):
          raise ValueError(
              f"Expected empty lines or 3 or 4 parameters per line in score"
              f" calibration file, but got {len(row)}.")
        # The first column is the scale; it must not be negative.
        scale = float(row[0])
        if scale < 0:
          raise ValueError(
              f"Expected scale to be a non-negative value, but got "
              f"{scale}.")

  def create_metadata(self) -> _metadata_fb.ProcessUnitT:
    """Creates the score calibration metadata based on the information.

    Returns:
      A Flatbuffers Python object of the score calibration metadata.
    """
    calibration_options = _metadata_fb.ScoreCalibrationOptionsT()
    calibration_options.scoreTransformation = self._score_transformation_type
    calibration_options.defaultScore = self._default_score

    process_unit = _metadata_fb.ProcessUnitT()
    process_unit.optionsType = (
        _metadata_fb.ProcessUnitOptions.ScoreCalibrationOptions)
    process_unit.options = calibration_options
    return process_unit

  def create_score_calibration_file_md(self) -> AssociatedFileMd:
    """Returns the associated file entry for the score calibration file."""
    return AssociatedFileMd(
        self._file_path,
        self._SCORE_CALIBRATION_FILE_DESCRIPTION,
        self._FILE_TYPE)
class TensorMd:
  """A container for common tensor metadata information.

  Attributes:
    name: name of the tensor.
    description: description of what the tensor is.
    min_values: per-channel minimum value of the tensor.
    max_values: per-channel maximum value of the tensor.
    content_type: content_type of the tensor.
    associated_files: information of the associated files in the tensor.
    tensor_name: name of the corresponding tensor [1] in the TFLite model. It is
      used to locate the corresponding tensor and decide the order of the tensor
      metadata [2] when populating model metadata.
  [1]:
    https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
  [2]:
    https://github.com/tensorflow/tflite-support/blob/b2a509716a2d71dfff706468680a729cc1604cff/tensorflow_lite_support/metadata/metadata_schema.fbs#L595-L612
  """

  def __init__(self,
               name: Optional[str] = None,
               description: Optional[str] = None,
               min_values: Optional[List[float]] = None,
               max_values: Optional[List[float]] = None,
               content_type: _metadata_fb.ContentProperties = (
                   _metadata_fb.ContentProperties.FeatureProperties),
               associated_files: Optional[List[AssociatedFileMd]] = None,
               tensor_name: Optional[str] = None):
    """Stores the tensor information.

    Args:
      name: name of the tensor.
      description: description of what the tensor is.
      min_values: per-channel minimum value of the tensor.
      max_values: per-channel maximum value of the tensor.
      content_type: content_type of the tensor.
      associated_files: `AssociatedFileMd` instances (or subclass instances)
        describing the files associated with the tensor.
      tensor_name: name of the corresponding tensor in the TFLite model.
    """
    self.name = name
    self.description = description
    self.min_values = min_values
    self.max_values = max_values
    self.content_type = content_type
    self.associated_files = associated_files
    self.tensor_name = tensor_name

  def create_metadata(self) -> _metadata_fb.TensorMetadataT:
    """Creates the tensor metadata based on the information.

    Returns:
      A Flatbuffers Python object of the tensor metadata.
    """
    tensor_metadata = _metadata_fb.TensorMetadataT()
    tensor_metadata.name = self.name
    tensor_metadata.description = self.description

    # Create min and max values. The stats object is always attached, even
    # when both value lists are None.
    stats = _metadata_fb.StatsT()
    stats.max = self.max_values
    stats.min = self.min_values
    tensor_metadata.stats = stats

    # Create content properties. The content type is compared by value (`==`)
    # rather than identity: FlatBuffers-generated enum members are expected to
    # be plain integers, for which `is` only works by interpreter accident.
    content = _metadata_fb.ContentT()
    if self.content_type == _metadata_fb.ContentProperties.FeatureProperties:
      content.contentProperties = _metadata_fb.FeaturePropertiesT()
    elif self.content_type == _metadata_fb.ContentProperties.ImageProperties:
      content.contentProperties = _metadata_fb.ImagePropertiesT()
    elif self.content_type == (
        _metadata_fb.ContentProperties.BoundingBoxProperties):
      content.contentProperties = _metadata_fb.BoundingBoxPropertiesT()
    elif self.content_type == _metadata_fb.ContentProperties.AudioProperties:
      content.contentProperties = _metadata_fb.AudioPropertiesT()
    # For any other content type no properties object is assigned here; only
    # the type tag below is written.
    content.contentPropertiesType = self.content_type
    tensor_metadata.content = content

    # TODO(b/174091474): check if multiple label files have populated locale.
    # Create associated files
    if self.associated_files:
      tensor_metadata.associatedFiles = [
          file.create_metadata() for file in self.associated_files
      ]
    return tensor_metadata
class InputImageTensorMd(TensorMd):
  """A container for input image tensor metadata information.

  Attributes:
    norm_mean: the mean value used in tensor normalization [1].
    norm_std: the std value used in the tensor normalization [1]. norm_mean and
      norm_std must have the same dimension.
    color_space_type: the color space type of the input image [2].
  [1]:
    https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters
  [2]:
    https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L172
  """

  # Min and max float values for image pixels.
  _MIN_PIXEL = 0.0
  _MAX_PIXEL = 255.0

  def __init__(
      self,
      name: Optional[str] = None,
      description: Optional[str] = None,
      norm_mean: Optional[List[float]] = None,
      norm_std: Optional[List[float]] = None,
      color_space_type: Optional[
          _metadata_fb.ColorSpaceType] = _metadata_fb.ColorSpaceType.UNKNOWN,
      tensor_type: Optional[_schema_fb.TensorType] = None):
    """Initializes the instance of InputImageTensorMd.

    Args:
      name: name of the tensor.
      description: description of what the tensor is.
      norm_mean: the mean value used in tensor normalization [1].
      norm_std: the std value used in the tensor normalization [1]. norm_mean
        and norm_std must have the same dimension.
      color_space_type: the color space type of the input image [2].
      tensor_type: data type of the tensor.
      [1]:
        https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters
      [2]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L172

    Raises:
      ValueError: if norm_mean and norm_std have different dimensions.
    """
    # The dimension check only applies when both lists are provided and
    # non-empty.
    if norm_std and norm_mean and len(norm_std) != len(norm_mean):
      raise ValueError(
          f"norm_mean and norm_std are expected to be the same dim. But got "
          f"{len(norm_mean)} and {len(norm_std)}")

    # Tensor types are compared by value (`==`) rather than identity:
    # FlatBuffers-generated enum members are expected to be plain integers,
    # for which `is` only works by interpreter accident.
    if tensor_type == _schema_fb.TensorType.UINT8:
      min_values = [_MIN_UINT8]
      max_values = [_MAX_UINT8]
    elif (tensor_type == _schema_fb.TensorType.FLOAT32 and norm_std and
          norm_mean):
      # Per-channel range of a pixel in [_MIN_PIXEL, _MAX_PIXEL] after
      # applying (pixel - mean) / std.
      min_values = [
          float(self._MIN_PIXEL - mean) / std
          for mean, std in zip(norm_mean, norm_std)
      ]
      max_values = [
          float(self._MAX_PIXEL - mean) / std
          for mean, std in zip(norm_mean, norm_std)
      ]
    else:
      # Uint8 and Float32 are the two major types currently. And Task library
      # doesn't support other types so far.
      min_values = None
      max_values = None

    super().__init__(name, description, min_values, max_values,
                     _metadata_fb.ContentProperties.ImageProperties)
    self.norm_mean = norm_mean
    self.norm_std = norm_std
    self.color_space_type = color_space_type

  def create_metadata(self) -> _metadata_fb.TensorMetadataT:
    """Creates the input image metadata based on the information.

    Returns:
      A Flatbuffers Python object of the input image metadata.
    """
    tensor_metadata = super().create_metadata()
    tensor_metadata.content.contentProperties.colorSpace = self.color_space_type

    # Create normalization parameters. The process unit is only added when
    # both mean and std are provided and non-empty.
    if self.norm_mean and self.norm_std:
      normalization = _metadata_fb.ProcessUnitT()
      normalization.optionsType = (
          _metadata_fb.ProcessUnitOptions.NormalizationOptions)
      normalization.options = _metadata_fb.NormalizationOptionsT()
      normalization.options.mean = self.norm_mean
      normalization.options.std = self.norm_std
      tensor_metadata.processUnits = [normalization]
    return tensor_metadata
class InputTextTensorMd(TensorMd):
  """Holds the metadata information of an input text tensor.

  Attributes:
    tokenizer_md: information of the tokenizer in the input text tensor, if any.
  """

  def __init__(self,
               name: Optional[str] = None,
               description: Optional[str] = None,
               tokenizer_md: Optional[RegexTokenizerMd] = None):
    """Initializes the instance of InputTextTensorMd.

    Args:
      name: name of the tensor.
      description: description of what the tensor is.
      tokenizer_md: information of the tokenizer in the input text tensor, if
        any. Only `RegexTokenizer` [1] is currently supported. If the tokenizer
        is `BertTokenizer` [2] or `SentencePieceTokenizer` [3], refer to
        `bert_nl_classifier.MetadataWriter`.
      [1]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475
      [2]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436
      [3]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473
    """
    super().__init__(name, description)
    self.tokenizer_md = tokenizer_md

  def create_metadata(self) -> _metadata_fb.TensorMetadataT:
    """Creates the input text metadata based on the information.

    The tokenizer type is validated here rather than in the constructor, so a
    `tokenizer_md` assigned after construction is checked as well.

    Returns:
      A Flatbuffers Python object of the input text metadata.

    Raises:
      ValueError: if the type of tokenizer_md is unsupported.
    """
    tokenizer = self.tokenizer_md
    # Only "no tokenizer" or a RegexTokenizerMd is accepted.
    if tokenizer is not None and not isinstance(tokenizer, RegexTokenizerMd):
      raise ValueError(
          f"The type of tokenizer_options, {type(self.tokenizer_md)}, is "
          f"unsupported")

    text_metadata = super().create_metadata()
    if self.tokenizer_md:
      text_metadata.processUnits = [self.tokenizer_md.create_metadata()]
    return text_metadata
class InputAudioTensorMd(TensorMd):
  """Holds the metadata information of an input audio tensor.

  Attributes:
    sample_rate: the sample rate in Hz when the audio was captured.
    channels: the channel count of the audio.
  """

  def __init__(self,
               name: Optional[str] = None,
               description: Optional[str] = None,
               sample_rate: int = 0,
               channels: int = 0):
    """Initializes the instance of InputAudioTensorMd.

    Args:
      name: name of the tensor.
      description: description of what the tensor is.
      sample_rate: the sample rate in Hz when the audio was captured.
      channels: the channel count of the audio.
    """
    audio_content = _metadata_fb.ContentProperties.AudioProperties
    super().__init__(name, description, content_type=audio_content)
    self.sample_rate = sample_rate
    self.channels = channels

  def create_metadata(self) -> _metadata_fb.TensorMetadataT:
    """Creates the input audio metadata based on the information.

    Validation happens here rather than in the constructor, so values assigned
    to the attributes after construction are checked as well.

    Returns:
      A Flatbuffers Python object of the input audio metadata.

    Raises:
      ValueError: if any value of sample_rate, channels is negative.
    """
    # 0 is the default value in Flatbuffers, so only negatives are rejected.
    rate, channel_count = self.sample_rate, self.channels
    if rate < 0:
      raise ValueError(
          f"sample_rate should be non-negative, but got {self.sample_rate}.")
    if channel_count < 0:
      raise ValueError(
          f"channels should be non-negative, but got {self.channels}.")

    audio_metadata = super().create_metadata()
    audio_properties = audio_metadata.content.contentProperties
    audio_properties.sampleRate = self.sample_rate
    audio_properties.channels = self.channels
    return audio_metadata
class ClassificationTensorMd(TensorMd):
  """A container for the classification tensor metadata information.

  Attributes:
    label_files: information of the label files [1] in the classification
      tensor.
    score_calibration_md: information of the score calibration operation [2] in
      the classification tensor.
  [1]:
    https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L95
  [2]:
    https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434
  """

  # Min and max float values for classification results.
  _MIN_FLOAT = 0.0
  _MAX_FLOAT = 1.0

  def __init__(self,
               name: Optional[str] = None,
               description: Optional[str] = None,
               label_files: Optional[List[LabelFileMd]] = None,
               tensor_type: Optional[_schema_fb.TensorType] = None,
               score_calibration_md: Optional[ScoreCalibrationMd] = None,
               tensor_name: Optional[str] = None):
    """Initializes the instance of ClassificationTensorMd.

    Args:
      name: name of the tensor.
      description: description of what the tensor is.
      label_files: information of the label files [1] in the classification
        tensor. The list is copied; the caller's list is not modified.
      tensor_type: data type of the tensor.
      score_calibration_md: information of the score calibration files operation
        [2] in the classification tensor.
      tensor_name: name of the corresponding tensor [3] in the TFLite model. It
        is used to locate the corresponding classification tensor and decide the
        order of the tensor metadata [4] when populating model metadata.
      [1]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L95
      [2]:
        https://github.com/tensorflow/tflite-support/blob/5e0cdf5460788c481f5cd18aab8728ec36cf9733/tensorflow_lite_support/metadata/metadata_schema.fbs#L434
      [3]:
        https://github.com/tensorflow/tensorflow/blob/cb67fef35567298b40ac166b0581cd8ad68e5a3a/tensorflow/lite/schema/schema.fbs#L1129-L1136
      [4]:
        https://github.com/tensorflow/tflite-support/blob/b2a509716a2d71dfff706468680a729cc1604cff/tensorflow_lite_support/metadata/metadata_schema.fbs#L595-L612
    """
    self.score_calibration_md = score_calibration_md

    # Tensor types are compared by value (`==`) rather than identity:
    # FlatBuffers-generated enum members are expected to be plain integers,
    # for which `is` only works by interpreter accident.
    if tensor_type == _schema_fb.TensorType.UINT8:
      min_values = [_MIN_UINT8]
      max_values = [_MAX_UINT8]
    elif tensor_type == _schema_fb.TensorType.FLOAT32:
      min_values = [self._MIN_FLOAT]
      max_values = [self._MAX_FLOAT]
    else:
      # Uint8 and Float32 are the two major types currently. And Task library
      # doesn't support other types so far.
      min_values = None
      max_values = None

    # Work on a copy: appending the score calibration file to the caller's
    # own list would leak it into every other user of that list (for example
    # when the same label_files list is reused for a second tensor).
    associated_files = list(label_files) if label_files else []
    if self.score_calibration_md:
      associated_files.append(
          self.score_calibration_md.create_score_calibration_file_md())

    super().__init__(name, description, min_values, max_values,
                     _metadata_fb.ContentProperties.FeatureProperties,
                     associated_files, tensor_name)

  def create_metadata(self) -> _metadata_fb.TensorMetadataT:
    """Creates the classification tensor metadata based on the information.

    Returns:
      A Flatbuffers Python object of the classification tensor metadata. The
      score calibration process unit is attached only when
      `score_calibration_md` is set.
    """
    tensor_metadata = super().create_metadata()
    if self.score_calibration_md:
      tensor_metadata.processUnits = [
          self.score_calibration_md.create_metadata()
      ]
    return tensor_metadata
class CategoryTensorMd(TensorMd):
  """Holds the metadata information of a category tensor."""

  def __init__(self,
               name: Optional[str] = None,
               description: Optional[str] = None,
               label_files: Optional[List[LabelFileMd]] = None):
    """Initializes a CategoryTensorMd object.

    Args:
      name: name of the tensor.
      description: description of what the tensor is.
      label_files: information of the label files [1] in the category tensor.
        The `file_type` of each given object is overwritten in place.
      [1]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L108
    """
    # In category tensors, label files are in the type of TENSOR_VALUE_LABELS.
    # NOTE: this retags the caller's LabelFileMd objects themselves, not
    # copies of them.
    if label_files:
      value_labels_type = _metadata_fb.AssociatedFileType.TENSOR_VALUE_LABELS
      for label_file in label_files:
        label_file.file_type = value_labels_type

    super().__init__(
        name=name, description=description, associated_files=label_files)
class BertInputTensorsMd:
  """A container for the input tensor metadata information of Bert models."""

  _IDS_NAME = "ids"
  _IDS_DESCRIPTION = "Tokenized ids of the input text."
  _MASK_NAME = "mask"
  _MASK_DESCRIPTION = ("Mask with 1 for real tokens and 0 for padding "
                       "tokens.")
  _SEGMENT_IDS_NAME = "segment_ids"
  _SEGMENT_IDS_DESCRIPTION = (
      "0 for the first sequence, 1 for the second sequence if exists.")

  def __init__(self,
               model_buffer: bytearray,
               ids_name: str,
               mask_name: str,
               segment_name: str,
               ids_md: Optional[TensorMd] = None,
               mask_md: Optional[TensorMd] = None,
               segment_ids_md: Optional[TensorMd] = None,
               tokenizer_md: Union[None, BertTokenizerMd,
                                   SentencePieceTokenizerMd] = None):
    """Initializes a BertInputTensorsMd object.

    `ids_name`, `mask_name`, and `segment_name` correspond to the `Tensor.name`
    in the TFLite schema, which help to determine the tensor order when
    populating metadata.

    Args:
      model_buffer: valid buffer of the model file.
      ids_name: name of the ids tensor, which represents the tokenized ids of
        the input text.
      mask_name: name of the mask tensor, which represents the mask with 1 for
        real tokens and 0 for padding tokens.
      segment_name: name of the segment ids tensor, where `0` stands for the
        first sequence, and `1` stands for the second sequence if exists.
      ids_md: input ids tensor information.
      mask_md: input mask tensor information.
      segment_ids_md: input segment tensor information.
      tokenizer_md: information of the tokenizer used to process the input
        string, if any. Supported tokenizers are: `BertTokenizer` [1] and
        `SentencePieceTokenizer` [2]. If the tokenizer is `RegexTokenizer`
        [3], refer to `nl_classifier.MetadataWriter`.
      [1]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L436
      [2]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L473
      [3]:
        https://github.com/tensorflow/tflite-support/blob/b80289c4cd1224d0e1836c7654e82f070f9eefaa/tensorflow_lite_support/metadata/metadata_schema.fbs#L475

    Raises:
      ValueError: if the given tensor names do not match the input tensor names
        read from the model, or if the type of tokenizer_md is unsupported.
    """
    self._input_names = [ids_name, mask_name, segment_name]

    # Get the input tensor names in order from the model. Later, we need to
    # order the input metadata according to this tensor order.
    self._ordered_input_names = writer_utils.get_input_tensor_names(
        model_buffer)

    # Verify that self._ordered_input_names (read from the model) and
    # self._input_names (collected from users) are aligned. Counters compare
    # the names regardless of order.
    if collections.Counter(self._ordered_input_names) != collections.Counter(
        self._input_names):
      # The user-supplied names go first and the model's names second, so
      # each list is labelled with where it actually came from.
      raise ValueError(
          f"The input tensor names ({self._input_names}) do not match "
          f"the tensor names read from the model "
          f"({self._ordered_input_names}).")

    # Fall back to default name/description pairs for any tensor the caller
    # did not describe.
    if ids_md is None:
      ids_md = TensorMd(name=self._IDS_NAME, description=self._IDS_DESCRIPTION)
    if mask_md is None:
      mask_md = TensorMd(
          name=self._MASK_NAME, description=self._MASK_DESCRIPTION)
    if segment_ids_md is None:
      segment_ids_md = TensorMd(
          name=self._SEGMENT_IDS_NAME,
          description=self._SEGMENT_IDS_DESCRIPTION)

    # The order of self._input_md matches the order of self._input_names.
    self._input_md = [ids_md, mask_md, segment_ids_md]

    if not isinstance(tokenizer_md,
                      (type(None), BertTokenizerMd, SentencePieceTokenizerMd)):
      raise ValueError(
          f"The type of tokenizer_options, {type(tokenizer_md)}, is unsupported"
      )
    self._tokenizer_md = tokenizer_md

  def create_input_tesnor_metadata(self) -> List[_metadata_fb.TensorMetadataT]:
    """Creates the input metadata for the three input tensors.

    The method name keeps its historical spelling so that existing callers
    continue to work.

    Returns:
      The tensor metadata objects, ordered like the input tensors of the model.
    """
    # The order of the three input tensors may vary with each model conversion.
    # We need to order the input metadata according to the tensor order in the
    # model.
    name_md_dict = dict(zip(self._input_names, self._input_md))
    return [
        name_md_dict[name].create_metadata()
        for name in self._ordered_input_names
    ]

  def create_input_process_unit_metadata(
      self) -> List[_metadata_fb.ProcessUnitT]:
    """Creates the input process unit metadata.

    Returns:
      A one-element list holding the tokenizer process unit, or an empty list
      when no tokenizer was provided.
    """
    if self._tokenizer_md:
      return [self._tokenizer_md.create_metadata()]
    return []

  def get_tokenizer_associated_files(self) -> List[str]:
    """Gets the associated files that are packed in the tokenizer.

    Returns:
      The result of `writer_utils.get_tokenizer_associated_files` for the
      tokenizer options, or an empty list when no tokenizer was provided.
    """
    if self._tokenizer_md:
      return writer_utils.get_tokenizer_associated_files(
          self._tokenizer_md.create_metadata().options)
    return []