-
Notifications
You must be signed in to change notification settings - Fork 7.4k
/
image_searcher_dataloader.py
155 lines (133 loc) · 6.1 KB
/
image_searcher_dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image DataLoader for Searcher task."""
import logging
import os
import numpy as np
import tensorflow as tf
from tensorflow_examples.lite.model_maker.core.api.api_util import mm_export
from tensorflow_examples.lite.model_maker.core.data_util import metadata_loader
from tensorflow_examples.lite.model_maker.core.data_util import searcher_dataloader
from tensorflow_lite_support.python.task.core import base_options as base_options_module
from tensorflow_lite_support.python.task.processor.proto import embedding_options_pb2
from tensorflow_lite_support.python.task.vision import image_embedder
from tensorflow_lite_support.python.task.vision.core import tensor_image
# Short private alias for the task-library BaseOptions class used by `create`.
_BaseOptions = base_options_module.BaseOptions
# Short private alias for the metadata-loading strategy enum used below.
_MetadataType = metadata_loader.MetadataType
@mm_export("searcher.ImageDataLoader")
class DataLoader(searcher_dataloader.DataLoader):
  """DataLoader class for Image Searcher Task.

  Uses an `image_embedder.ImageEmbedder` to turn image files into embedding
  vectors and a `metadata_loader.MetadataLoader` to attach metadata to each
  embedding. Each call to `load_from_folder` appends one stacked array of
  embeddings to `self._cache_dataset_list` and the matching metadata entries to
  `self._metadata` (both presumably initialized by the base
  `searcher_dataloader.DataLoader` -- confirm there).
  """

  def __init__(
      self,
      embedder: image_embedder.ImageEmbedder,
      metadata_type: _MetadataType = _MetadataType.FROM_FILE_NAME) -> None:
    """Initializes DataLoader for Image Searcher task.

    Args:
      embedder: Embedder to generate embedding from raw input image.
      metadata_type: Type of MetadataLoader to load metadata for each input
        data. By default, load the file name as metadata for each input data.

    Raises:
      ValueError: If `metadata_type` is neither `FROM_FILE_NAME` nor
        `FROM_DAT_FILE`.
    """
    self._embedder = embedder
    # The base class only needs the model path; reuse the one the embedder
    # was created from so both stay in sync.
    super().__init__(embedder_path=embedder.options.base_options.file_name)

    # Creates the metadata loader.
    self.metadata_type = metadata_type
    if metadata_type is _MetadataType.FROM_FILE_NAME:
      self._metadata_loader = metadata_loader.MetadataLoader.from_file_name()
    elif metadata_type is _MetadataType.FROM_DAT_FILE:
      self._metadata_loader = metadata_loader.MetadataLoader.from_dat_file()
    else:
      raise ValueError(f"Unsupported metadata_type: {metadata_type}.")

  @classmethod
  def create(cls,
             image_embedder_path: str,
             metadata_type: _MetadataType = _MetadataType.FROM_FILE_NAME,
             l2_normalize: bool = False) -> "DataLoader":
    """Creates DataLoader for the Image Searcher task.

    Args:
      image_embedder_path: Path to the ".tflite" image embedder model.
      metadata_type: Type of MetadataLoader to load metadata for each input
        image based on image path. By default, load the file name as metadata
        for each input image.
      l2_normalize: Whether to normalize the returned feature vector with L2
        norm. Use this option only if the model does not already contain a
        native L2_NORMALIZATION TF Lite Op. In most cases, this is already the
        case and L2 norm is thus achieved through TF Lite inference.

    Returns:
      DataLoader object created for the Image Searcher task.

    Raises:
      ValueError: If `metadata_type` is unsupported (raised by `__init__`).
    """
    # Creates ImageEmbedder. The model bytes are read through tf.io.gfile and
    # passed as `file_content`; the absolute path is still recorded in
    # `file_name` because `__init__` forwards it to the base class.
    image_embedder_path = os.path.abspath(image_embedder_path)
    with tf.io.gfile.GFile(image_embedder_path, "rb") as f:
      image_embedder_content = f.read()
    base_options = _BaseOptions(
        file_content=image_embedder_content, file_name=image_embedder_path)
    embedding_options = embedding_options_pb2.EmbeddingOptions(
        l2_normalize=l2_normalize)
    options = image_embedder.ImageEmbedderOptions(
        base_options=base_options, embedding_options=embedding_options)
    embedder = image_embedder.ImageEmbedder.create_from_options(options)
    return cls(embedder, metadata_type)

  def load_from_folder(self, path: str, mode: str = "r") -> None:
    """Loads image data from folder.

    Users can load images from different folders one by one. For instance,

    ```
    # Creates data_loader instance.
    data_loader = image_searcher_dataloader.DataLoader.create(tflite_path)

    # Loads images, first from `image_path1` and secondly from `image_path2`.
    data_loader.load_from_folder(image_path1)
    data_loader.load_from_folder(image_path2)
    ```

    Files that cannot be decoded as images or embedded are skipped with a
    warning. Files ending in ".dat" (case-insensitive) are never treated as
    images.

    Args:
      path: image directory to be loaded.
      mode: mode in which the file is opened, used when metadata_type is
        FROM_DAT_FILE. Only 'r' and 'rb' are supported. 'r' means opening for
        reading, 'rb' means opening for reading binary.

    Raises:
      ValueError: If no image under `path` could be read and embedded (e.g. the
        folder is empty or does not exist).
    """
    embedding_list = []
    metadata_list = []

    # Gets the image files in the folder (recursively) and loads images.
    for root, _, files in tf.io.gfile.walk(path):
      for name in files:
        image_path = os.path.join(root, name)
        # ".dat" files are presumably the metadata side-files consumed by the
        # FROM_DAT_FILE loader, not images -- skip them.
        if image_path.lower().endswith(".dat"):
          continue

        try:
          with tf.io.gfile.GFile(image_path, "rb") as f:
            buffer = f.read()
          image = tensor_image.TensorImage.create_from_buffer(buffer)
        except RuntimeError as e:
          logging.warning(
              "Can't read image from the image path %s with the error %s",
              image_path, e)
          continue

        try:
          embedding = self._embedder.embed(
              image).embeddings[0].feature_vector.value
        except (RuntimeError, ValueError) as e:
          logging.warning("Can't get the embedding of %s with the error %s",
                          image_path, e)
          continue

        if self.metadata_type is _MetadataType.FROM_DAT_FILE:
          metadata = self._metadata_loader.load(image_path, mode=mode)
        else:
          metadata = self._metadata_loader.load(image_path)

        # Append both together so the two lists stay index-aligned.
        embedding_list.append(embedding)
        metadata_list.append(metadata)

        num_processed = len(embedding_list)
        if num_processed % 1000 == 0:
          logging.info("Processed %d images.", num_processed)

    # np.stack raises an opaque error on an empty sequence; fail with a clear
    # message instead (same exception type as before).
    if not embedding_list:
      raise ValueError(
          f"No image could be loaded and embedded from the folder: {path}")

    cache_dataset = np.stack(embedding_list)
    self._cache_dataset_list.append(cache_dataset)
    self._metadata = self._metadata + metadata_list