[go: nahoru, domu]

Skip to content

Commit

Permalink
TNN convert 工具支持导出 fp16 模型 (#1875)
Browse files Browse the repository at this point in the history
* [Convert][UPD]1. support save fp16 model;
  • Loading branch information
gttiankai committed Feb 1, 2023
1 parent 3a96c41 commit 30863dc
Show file tree
Hide file tree
Showing 26 changed files with 589 additions and 107 deletions.
6 changes: 6 additions & 0 deletions source/tnn/interpreter/raw_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ RawBuffer ConvertHalfHandle(RawBuffer &buf) {
auto data_count = buf.GetDataCount();
RawBuffer buf_f32(data_count * sizeof(float));
ConvertFromHalfToFloat(buf.force_to<void *>(), buf_f32.force_to<float *>(), data_count);
buf_f32.SetDataType(DATA_TYPE_FLOAT);
buf_f32.SetBufferDims(buf.GetBufferDims());
return buf_f32;
} else {
return buf;
Expand All @@ -188,6 +190,7 @@ RawBuffer ConvertFloatToBFP16(RawBuffer &buf) {
RawBuffer buf_bfp16(data_count * sizeof(bfp16_t));
ConvertFromFloatToBFP16(buf.force_to<float *>(), buf_bfp16.force_to<void *>(), data_count);
buf_bfp16.SetDataType(DATA_TYPE_BFP16);
buf_bfp16.SetBufferDims(buf.GetBufferDims());
return buf_bfp16;
} else {
return buf;
Expand All @@ -204,6 +207,7 @@ RawBuffer ConvertHalfToBFP16(RawBuffer &buf) {
RawBuffer buf_bfp16(data_count * sizeof(bfp16_t));
ConvertFromFloatToBFP16(buf_fp32.force_to<float *>(), buf_bfp16.force_to<void *>(), data_count);
buf_bfp16.SetDataType(DATA_TYPE_BFP16);
buf_bfp16.SetBufferDims(buf.GetBufferDims());
return buf_bfp16;
} else {
return buf;
Expand Down Expand Up @@ -238,6 +242,8 @@ RawBuffer ConvertFloatToFP16(RawBuffer &buf) {
if (buf.GetBytesSize() > 0 && buf.GetDataType() == DATA_TYPE_FLOAT) {
int data_count = buf.GetDataCount();
RawBuffer buf_fp16(data_count * sizeof(fp16_t));
buf_fp16.SetDataType(DATA_TYPE_HALF);
buf_fp16.SetBufferDims(buf.GetBufferDims());
ConvertFromFloatToHalf(buf.force_to<float *>(), buf_fp16.force_to<fp16_t *>(), data_count);
return buf_fp16;
} else {
Expand Down
45 changes: 45 additions & 0 deletions tools/converter/source/resource/reource_base_convert.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Tencent is pleased to support the open source community by making TNN available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "reource_base_convert.h"
namespace TNN_CONVERTER {

// Storage for the lazily created manager instance; stays null until get() is first called.
ResourceConvertManager *ResourceConvertManager::resource_convert_manager_ = nullptr;

// Returns the shared manager instance, allocating it on the first call.
// No synchronization is performed here.
TNN_CONVERTER::ResourceConvertManager *TNN_CONVERTER::ResourceConvertManager::get() {
    if (resource_convert_manager_ != nullptr) {
        return resource_convert_manager_;
    }
    resource_convert_manager_ = new ResourceConvertManager;
    return resource_convert_manager_;
}
// Registers `resource_convert` under `tnn_op_name`.
// If the name is already registered the existing entry is kept and the map is left unchanged.
// NOTE(review): in that case the passed converter is neither stored nor deleted here.
void TNN_CONVERTER::ResourceConvertManager::insert(const std::string &tnn_op_name,
                                                   TNN_CONVERTER::ResourceBaseConvert *resource_convert) {
    resource_convert_map_.emplace(tnn_op_name, resource_convert);
}

// Looks up the converter registered under `tnn_op_name`.
// Returns the stored pointer, or nullptr when no converter was registered for that name.
TNN_CONVERTER::ResourceBaseConvert *TNN_CONVERTER::ResourceConvertManager::search(const std::string &tnn_op_name) {
    const auto found = resource_convert_map_.find(tnn_op_name);
    return found != resource_convert_map_.end() ? found->second : nullptr;
}
// Releases every converter stored in the map.
//
// The destructor must not `delete resource_convert_manager_`:
//   - if `this` is the shared instance, that delete would re-enter this destructor on the
//     same object (unbounded recursion / double free);
//   - if `this` is some other instance, it would destroy the shared instance behind the back
//     of everyone holding the pointer returned by get().
// Instead, the static pointer is only cleared when the shared instance itself is going away,
// so a later get() allocates a fresh manager rather than returning a dangling pointer.
ResourceConvertManager::~ResourceConvertManager() {
    for (auto &entry : resource_convert_map_) {
        delete entry.second;
    }
    resource_convert_map_.clear();
    if (resource_convert_manager_ == this) {
        resource_convert_manager_ = nullptr;
    }
}
} // namespace TNN_CONVERTER
70 changes: 70 additions & 0 deletions tools/converter/source/resource/reource_base_convert.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Tencent is pleased to support the open source community by making TNN available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef TNN_TOOLS_CONVERTER_SOURCE_RESOURCE_REOURCE_BASE_CONVERT_H_
#define TNN_TOOLS_CONVERTER_SOURCE_RESOURCE_REOURCE_BASE_CONVERT_H_
#include <map>

#include "tnn/core/common.h"
#include "tnn/core/status.h"
#include "tnn/interpreter/layer_param.h"
#include "tnn/interpreter/net_resource.h"
#include "tnn/interpreter/net_structure.h"

namespace TNN_CONVERTER {
// Abstract interface for per-operator resource converters.
// Concrete converters implement ConvertToHalfResource for the layer resource type they handle.
class ResourceBaseConvert {
public:
    ResourceBaseConvert() = default;

    virtual ~ResourceBaseConvert() = default;

    // Converts the weights held by `layer_resource` to half precision.
    // `param` is the parameter object of the layer that owns the resource.
    // Returns a TNN status describing the outcome.
    virtual TNN_NS::Status ConvertToHalfResource(
        std::shared_ptr<TNN_NS::LayerParam> param,
        std::shared_ptr<TNN_NS::LayerResource> layer_resource) = 0;
};

// Registry that maps a TNN operator type name to the converter handling its resources.
// A shared instance is obtained through get().
class ResourceConvertManager {
public:
    ResourceConvertManager() = default;
    ~ResourceConvertManager();

    // The map holds raw converter pointers that the destructor deletes, so a copied manager
    // would delete every converter a second time. Copying is therefore disabled.
    ResourceConvertManager(const ResourceConvertManager&)            = delete;
    ResourceConvertManager& operator=(const ResourceConvertManager&) = delete;

    // Returns the shared manager instance.
    static ResourceConvertManager* get();

    // Registers `resource_base_convert` under `tnn_op_name`.
    void insert(const std::string& tnn_op_name, ResourceBaseConvert* resource_base_convert);

    // Returns the converter registered under `tnn_op_name`, or nullptr when there is none.
    ResourceBaseConvert* search(const std::string& tnn_op_name);

private:
    static ResourceConvertManager* resource_convert_manager_;
    // Operator type name -> converter.
    std::map<std::string, ResourceBaseConvert*> resource_convert_map_;
};
template <class T>
class ResourceConvertRegister {
public:
explicit ResourceConvertRegister(const std::string& tnn_op_name) {
T* convert = new T;
ResourceConvertManager* resource_convert_manager = ResourceConvertManager::get();
resource_convert_manager->insert(tnn_op_name, convert);
}
};

// Defines a ResourceConvertRegister<Resource<op_convert_name>Convert> object named
// g_resource_converter_<tnn_op_name>_ and constructs it with the stringified tnn_op_name,
// which registers a newly created converter under that operator name.
// No trailing semicolon is included; the use site supplies it.
#define REGISTER_RESOURCE_CONVERT(op_convert_name, tnn_op_name) \
ResourceConvertRegister<Resource##op_convert_name##Convert> g_resource_converter_##tnn_op_name##_(#tnn_op_name)

} // namespace TNN_CONVERTER

// Declares class Resource<op_convert_name>Convert, a ResourceBaseConvert subclass.
// Only the declaration of ConvertToHalfResource is generated; the file using the macro must
// provide its definition.
// The expansion names ResourceBaseConvert unqualified, so the macro has to be used inside
// namespace TNN_CONVERTER.
// `override` lets the compiler reject any drift between this signature and the base class.
// The closing `;` is kept so that use sites written with or without a trailing semicolon
// keep compiling.
#define DECLARE_RESOURCE_CONVERT(op_convert_name)                                                     \
    class Resource##op_convert_name##Convert : public ResourceBaseConvert {                           \
    public:                                                                                           \
        Resource##op_convert_name##Convert() = default;                                               \
        ~Resource##op_convert_name##Convert() override = default;                                     \
        TNN_NS::Status ConvertToHalfResource(std::shared_ptr<TNN_NS::LayerParam> param,               \
                                             std::shared_ptr<TNN_NS::LayerResource> layer_resource)   \
            override;                                                                                 \
    };

#endif // TNN_TOOLS_CONVERTER_SOURCE_RESOURCE_REOURCE_BASE_CONVERT_H_
35 changes: 35 additions & 0 deletions tools/converter/source/resource/resource_batchnorm_convert.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Tencent is pleased to support the open source community by making TNN available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "tnn/core/layer_type.h"
#include "tnn/interpreter/layer_resource.h"
#include "tnn/interpreter/raw_buffer.h"
#include "tnn/utils/half_utils.h"
#include "tools/converter/source/resource/reource_base_convert.h"

namespace TNN_CONVERTER {

DECLARE_RESOURCE_CONVERT(BatchNorm);

// Converts the scale and bias buffers of a batch-norm layer resource to FP16.
//
// param:          parameter object of the layer (not used by this converter).
// layer_resource: resource to convert in place; expected to be a BatchNormLayerResource.
//
// Returns TNN_CONVERT_OK on success. Returns TNNERR_CONVERT_UNSUPPORT_LAYER when
// `layer_resource` is null or is not a BatchNormLayerResource, instead of dereferencing the
// null result of the cast.
TNN_NS::Status ResourceBatchNormConvert::ConvertToHalfResource(std::shared_ptr<TNN_NS::LayerParam> param,
                                                               std::shared_ptr<TNN_NS::LayerResource> layer_resource) {
    auto resource = std::dynamic_pointer_cast<TNN_NS::BatchNormLayerResource>(layer_resource);
    if (resource == nullptr) {
        // The caller logs the failing layer name when a non-OK status is returned.
        return TNN_NS::TNNERR_CONVERT_UNSUPPORT_LAYER;
    }
    resource->scale_handle = TNN_NS::ConvertFloatToFP16(resource->scale_handle);
    resource->bias_handle  = TNN_NS::ConvertFloatToFP16(resource->bias_handle);

    return TNN_NS::TNN_CONVERT_OK;
}

REGISTER_RESOURCE_CONVERT(BatchNorm, BatchNormCxx);
} // namespace TNN_CONVERTER
45 changes: 45 additions & 0 deletions tools/converter/source/resource/resource_binary_convert.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Tencent is pleased to support the open source community by making TNN available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "tnn/core/layer_type.h"
#include "tnn/interpreter/layer_resource.h"
#include "tnn/interpreter/raw_buffer.h"
#include "tnn/utils/half_utils.h"
#include "tools/converter/source/resource/reource_base_convert.h"

namespace TNN_CONVERTER {

DECLARE_RESOURCE_CONVERT(Binary);

// Converts the element buffer of an element-wise layer resource to FP16.
//
// param:          parameter object of the layer (not used by this converter).
// layer_resource: resource to convert in place; expected to be an EltwiseLayerResource.
//
// Returns TNN_CONVERT_OK on success. Returns TNNERR_CONVERT_UNSUPPORT_LAYER when
// `layer_resource` is null or is not an EltwiseLayerResource, instead of dereferencing the
// null result of the cast.
TNN_NS::Status ResourceBinaryConvert::ConvertToHalfResource(std::shared_ptr<TNN_NS::LayerParam> param,
                                                            std::shared_ptr<TNN_NS::LayerResource> layer_resource) {
    auto resource = std::dynamic_pointer_cast<TNN_NS::EltwiseLayerResource>(layer_resource);
    if (resource == nullptr) {
        // The caller logs the failing layer name when a non-OK status is returned.
        return TNN_NS::TNNERR_CONVERT_UNSUPPORT_LAYER;
    }
    resource->element_handle = TNN_NS::ConvertFloatToFP16(resource->element_handle);

    return TNN_NS::TNN_CONVERT_OK;
}

// Every operator type below shares this converter.
REGISTER_RESOURCE_CONVERT(Binary, Add);
REGISTER_RESOURCE_CONVERT(Binary, Sub);
REGISTER_RESOURCE_CONVERT(Binary, Mul);
REGISTER_RESOURCE_CONVERT(Binary, Div);
REGISTER_RESOURCE_CONVERT(Binary, Minimum);
REGISTER_RESOURCE_CONVERT(Binary, Maximum);
REGISTER_RESOURCE_CONVERT(Binary, Less);
REGISTER_RESOURCE_CONVERT(Binary, Greater);
REGISTER_RESOURCE_CONVERT(Binary, And);
REGISTER_RESOURCE_CONVERT(Binary, Not);
REGISTER_RESOURCE_CONVERT(Binary, Square);
REGISTER_RESOURCE_CONVERT(Binary, SquaredDifference);
} // namespace TNN_CONVERTER
34 changes: 34 additions & 0 deletions tools/converter/source/resource/resource_conv_convert.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Tencent is pleased to support the open source community by making TNN available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "tnn/core/layer_type.h"
#include "tnn/interpreter/layer_resource.h"
#include "tnn/interpreter/raw_buffer.h"
#include "tnn/utils/half_utils.h"
#include "tools/converter/source/resource/reource_base_convert.h"

namespace TNN_CONVERTER {

DECLARE_RESOURCE_CONVERT(Conv);

// Converts the filter and bias buffers of a convolution layer resource to FP16.
//
// param:          parameter object of the layer (not used by this converter).
// layer_resource: resource to convert in place; expected to be a ConvLayerResource.
//
// Returns TNN_CONVERT_OK on success. Returns TNNERR_CONVERT_UNSUPPORT_LAYER when
// `layer_resource` is null or is not a ConvLayerResource, instead of dereferencing the null
// result of the cast.
TNN_NS::Status ResourceConvConvert::ConvertToHalfResource(std::shared_ptr<TNN_NS::LayerParam> param,
                                                          std::shared_ptr<TNN_NS::LayerResource> layer_resource) {
    auto resource = std::dynamic_pointer_cast<TNN_NS::ConvLayerResource>(layer_resource);
    if (resource == nullptr) {
        // The caller logs the failing layer name when a non-OK status is returned.
        return TNN_NS::TNNERR_CONVERT_UNSUPPORT_LAYER;
    }
    resource->filter_handle = TNN_NS::ConvertFloatToFP16(resource->filter_handle);
    resource->bias_handle   = TNN_NS::ConvertFloatToFP16(resource->bias_handle);
    return TNN_NS::TNN_CONVERT_OK;
}

REGISTER_RESOURCE_CONVERT(Conv, Convolution);
} // namespace TNN_CONVERTER
57 changes: 57 additions & 0 deletions tools/converter/source/resource/resource_convert.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Tencent is pleased to support the open source community by making TNN available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "tools/converter/source/resource/resource_convert.h"

#include "tnn/core/common.h"
#include "tools/converter/source/resource/reource_base_convert.h"

namespace TNN_CONVERTER {

// Stores the conversion mode that a later call to converter() will apply.
// Always returns TNN_OK.
TNN_NS::Status ResourceConvert::SetResourceConvertType(ResourceConvertType resource_convert_type) {
    resource_convert_type_ = resource_convert_type;
    return TNN_NS::TNN_OK;
}
// Applies the configured conversion mode to the resources of every layer in `net_structure`.
//
// net_structure: supplies the layers (name, type string, param) to walk.
// net_resource:  its resource_map entries are converted in place.
//
// Only RESOURCE_CONVERT_HALF performs work here; every other mode leaves the resources
// untouched and returns TNN_OK.
// Layers without a non-null entry in the resource map are skipped.
//
// Returns TNNERR_CONVERT_UNSUPPORT_LAYER when a layer that owns a resource has no registered
// converter, the converter's status when a conversion fails, and TNN_OK otherwise.
TNN_NS::Status ResourceConvert::converter(TNN_NS::NetStructure& net_structure, TNN_NS::NetResource& net_resource) {
    if (resource_convert_type_ != RESOURCE_CONVERT_HALF) {
        return TNN_NS::TNN_OK;
    }

    // convert float weight to half weight
    auto& resource_map = net_resource.resource_map;
    for (auto& layer : net_structure.layers) {
        if (layer == nullptr) {
            // Nothing to look up for an empty slot; avoid dereferencing it.
            continue;
        }
        // One lookup per layer; the iterator is reused for both the check and the conversion.
        const auto resource_iter = resource_map.find(layer->name);
        if (resource_iter == resource_map.end() || resource_iter->second == nullptr) {
            continue;
        }
        ResourceBaseConvert* convert = ResourceConvertManager::get()->search(layer->type_str);
        if (convert == nullptr) {
            LOGE("The ResourceConverter do not support layer:%s \n", layer->name.c_str());
            LOGE("The unsupported operator type is:%s\n", layer->type_str.c_str());
            return TNN_NS::TNNERR_CONVERT_UNSUPPORT_LAYER;
        }
        auto status = convert->ConvertToHalfResource(layer->param, resource_iter->second);
        if (status != TNN_NS::TNN_CONVERT_OK) {
            LOGE("ResourceConvert failed for %s\n", layer->name.c_str());
            return status;
        }
    }
    return TNN_NS::TNN_OK;
}
} // namespace TNN_CONVERTER
45 changes: 45 additions & 0 deletions tools/converter/source/resource/resource_convert.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Tencent is pleased to support the open source community by making TNN available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef TNN_TOOLS_CONVERTER_SOURCE_RESOURCE_RESOURCE_CONVERT_H_
#define TNN_TOOLS_CONVERTER_SOURCE_RESOURCE_RESOURCE_CONVERT_H_

#include "tnn/core/common.h"
#include "tnn/core/status.h"
#include "tnn/interpreter/net_resource.h"
#include "tnn/interpreter/net_structure.h"

namespace TNN_CONVERTER {

// Conversion mode applied to model resources by ResourceConvert.
enum ResourceConvertType {
    // Leave every resource as it is.
    RESOURCE_KEEP_ORIGINAL = 0,
    // Convert float weights to half precision.
    RESOURCE_CONVERT_HALF = 1,
    // NOTE(review): presumably converts weights to float; confirm against the implementation.
    RESOURCE_CONVERT_FLOAT = 2,
    // NOTE(review): presumably dynamic-range quantization; confirm against the implementation.
    RESOURCE_DYNAMIC_RANGE_QUANTIZATION = 3,
};

// Applies a selected conversion mode to the resources of a parsed model.
// Typical use: choose the mode with SetResourceConvertType(), then call converter().
class ResourceConvert {
public:
    ResourceConvert()  = default;
    ~ResourceConvert() = default;

    // Selects the conversion mode used by converter().
    TNN_NS::Status SetResourceConvertType(ResourceConvertType resource_convert_type);

    // Converts the resources in `net_resource` for the layers listed in `net_structure`,
    // according to the currently selected mode.
    TNN_NS::Status converter(TNN_NS::NetStructure& net_structure, TNN_NS::NetResource& net_resource);

private:
    // Resources are kept unchanged until a different mode is selected.
    ResourceConvertType resource_convert_type_{RESOURCE_KEEP_ORIGINAL};
};
} // namespace TNN_CONVERTER

#endif // TNN_TOOLS_CONVERTER_SOURCE_RESOURCE_RESOURCE_CONVERT_H_
Loading

0 comments on commit 30863dc

Please sign in to comment.