-
Notifications
You must be signed in to change notification settings - Fork 74k
/
array.h
156 lines (129 loc) · 5.05 KB
/
array.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_ARRAY_H_
#define TENSORFLOW_LITE_ARRAY_H_
#include <cstring>
#include <initializer_list>
#include <memory>
#include <type_traits>
#include <vector>
#include "tensorflow/lite/core/c/common.h"
namespace tflite {
/// TfLite*Array helpers
namespace array_internal {
// Deleter used by the TfLite*Array unique_ptr aliases. There is one overload
// per supported array struct; the bodies are not defined in this header.
struct TfLiteArrayDeleter {
  void operator()(TfLiteIntArray* array);
  void operator()(TfLiteFloatArray* array);
};
// Trait mapping an element type to the TfLite array struct that stores it.
// Only `int` and `float` are specialized here; the primary template is left
// undefined, so other element types have no mapping.
template <typename T>
struct TfLiteArrayInfo;
template <>
struct TfLiteArrayInfo<int> {
  using Type = TfLiteIntArray;
};
template <>
struct TfLiteArrayInfo<float> {
  using Type = TfLiteFloatArray;
};
}  // namespace array_internal
// Owning pointer to the TfLite array struct whose elements have type `T`
// (as selected by array_internal::TfLiteArrayInfo). The array is released
// through array_internal::TfLiteArrayDeleter.
template <typename T>
using TfLiteArrayUniquePtr =
    std::unique_ptr<typename array_internal::TfLiteArrayInfo<T>::Type,
                    array_internal::TfLiteArrayDeleter>;
// Owning pointer to a `TfLiteIntArray`.
using IntArrayUniquePtr = TfLiteArrayUniquePtr<int>;
// Owning pointer to a `TfLiteFloatArray`.
using FloatArrayUniquePtr = TfLiteArrayUniquePtr<float>;
// Allocates (using malloc) a TfLite array able to hold `size` elements of type
// `T` and returns it as an owning pointer.
//
// `T` defaults to `int` because integer arrays account for the overwhelming
// majority of uses. Only the `int` and `float` specializations below are
// defined in this header.
template <typename T = int>
TfLiteArrayUniquePtr<T> BuildTfLiteArray(int size);
// `int` specialization: creates the array with TfLiteIntArrayCreate and hands
// ownership to an IntArrayUniquePtr.
template <>
inline IntArrayUniquePtr BuildTfLiteArray<int>(const int size) {
  TfLiteIntArray* const created = TfLiteIntArrayCreate(size);
  return IntArrayUniquePtr(created);
}
// `float` specialization: creates the array with TfLiteFloatArrayCreate and
// hands ownership to a FloatArrayUniquePtr.
template <>
inline FloatArrayUniquePtr BuildTfLiteArray<float>(const int size) {
  TfLiteFloatArray* const created = TfLiteFloatArrayCreate(size);
  return FloatArrayUniquePtr(created);
}
// Allocates a TfLite array of `size` elements and initializes it from `values`.
//
// `values` is expected to hold at least `size` elements. If `values` is null,
// the array is still allocated but nothing is copied into it.
//
// The element type of the result is `T` when it is explicitly specified, and
// `U` (the pointee type of `values`) otherwise. When the two types differ, each
// element is converted with a static_cast; when they match, the elements are
// copied with a single memcpy.
template <class T = void, class U,
          class Type = std::conditional_t<std::is_same<T, void>::value, U, T>>
TfLiteArrayUniquePtr<Type> BuildTfLiteArray(const int size,
                                            const U* const values) {
  TfLiteArrayUniquePtr<Type> array = BuildTfLiteArray<Type>(size);
  // If size is 0, the array pointer may be null.
  if (array && values) {
    if (std::is_same<Type, U>::value) {
      // Identical element types: bulk copy. Qualified because <cstring> is
      // only guaranteed to declare memcpy in namespace std.
      std::memcpy(array->data, values, size * sizeof(Type));
    } else {
      for (int i = 0; i < size; ++i) {
        array->data[i] = static_cast<Type>(values[i]);
      }
    }
  }
  return array;
}
// Allocates a TfLite array and initializes it from a built-in C array.
//
// The length `N` is deduced from the type of `values`, so the result has `N`
// elements initialized from all of `values`. The element type of the result is
// the element type `T` of `values`.
template <class T, size_t N>
TfLiteArrayUniquePtr<T> BuildTfLiteArray(const T (&values)[N]) {
  return BuildTfLiteArray<T>(static_cast<int>(N), values);
}
// Allocates a TfLite array and copies the elements of a contiguous container.
//
// This overload only participates in overload resolution for types that
// expose `data()` and `size()` member functions: the defaulted `ElementType`
// and `SizeType` template parameters fail substitution otherwise (SFINAE). The
// detection helpers from Abseil cannot be reused in this code, hence the
// hand-rolled check.
//
// As with the other overloads, the element type of the result may be given
// explicitly as `T`; when `T` is left as `void`, it is deduced from the
// container's element type.
template <
    class T = void, class Container,
    class ElementType =
        std::decay_t<decltype(*std::declval<Container>().data())>,
    class SizeType = std::decay_t<decltype(std::declval<Container>().size())>,
    class Type =
        std::conditional_t<std::is_same<T, void>::value, ElementType, T>>
TfLiteArrayUniquePtr<Type> BuildTfLiteArray(const Container& values) {
  const auto first = values.data();
  const int count = static_cast<int>(values.size());
  return BuildTfLiteArray<Type>(count, first);
}
// Allocates a TfLite array sized to a braced initializer list and initializes
// it with the list's elements; the element type of the result is `T`.
template <class T>
TfLiteArrayUniquePtr<T> BuildTfLiteArray(
    const std::initializer_list<T>& values) {
  const int count = static_cast<int>(values.size());
  return BuildTfLiteArray<T>(count, values.begin());
}
// Allocates a new TfLiteIntArray holding a copy of the `other.size` elements
// of `other`.
inline IntArrayUniquePtr BuildTfLiteArray(const TfLiteIntArray& other) {
  const int element_count = other.size;
  const int* const elements = other.data;
  return BuildTfLiteArray<int>(element_count, elements);
}
// Allocates a new TfLiteFloatArray holding a copy of the `other.size` elements
// of `other`.
inline FloatArrayUniquePtr BuildTfLiteArray(const TfLiteFloatArray& other) {
  const int element_count = other.size;
  const float* const elements = other.data;
  return BuildTfLiteArray<float>(element_count, elements);
}
} // namespace tflite
#endif // TENSORFLOW_LITE_ARRAY_H_