!16005 add tflite padv2 parser

Merge pull request !16005 from hangq/master
i-robot 2021-06-18 08:44:30 +00:00 committed by Gitee
commit 92d25d57d0
14 changed files with 169 additions and 21 deletions

View File

@@ -726,6 +726,7 @@ TypeId GetParameterDtype(const ParameterPtr &param_node) {
}
STATUS UpdateFuncGraphInputsAndOutputsDtype(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
+ // update graph inputs dtype
size_t idx = 0;
for (auto &input : func_graph->get_inputs()) {

View File

@@ -37,19 +37,22 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
MindsporeImporter ms_import;
func_graph = ms_import.ImportMindIR(flag);
- if (func_graph == nullptr) {
- MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
- return nullptr;
- }
} else {
model_parser_ = ModelParserRegistry::GetInstance()->GetModelParser(flag.fmkIn);
if (model_parser_ == nullptr) {
MS_LOG(ERROR) << "get funcGraph failed for fmk:" << flag.fmkIn;
return nullptr;
}
func_graph = model_parser_->Parse(flag);
}
+ if (func_graph == nullptr) {
+ MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << flag.fmkIn;
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NOT_SUPPORT);
+ return nullptr;
+ }
if (UpdateFuncGraphInputsAndOutputsDtype(func_graph) != RET_OK) {
MS_LOG(ERROR) << "update graph inputs and outputs dtype failed.";
MS_LOG(ERROR) << "Update graph inputs and outputs dtype failed.";
return nullptr;
}
return func_graph;

View File

@@ -270,6 +270,15 @@ int Flags::Init(int argc, const char **argv) {
}
}
+ if (save_fp16_str_ == "on") {
+ save_fp16_ = true;
+ } else if (save_fp16_str_ == "off") {
+ save_fp16_ = false;
+ } else {
+ std::cerr << "Init save_fp16 failed.";
+ return RET_INPUT_PARAM_INVALID;
+ }
ret = InitInputOutputDataType();
if (ret != RET_OK) {
std::cerr << "Init input output datatype failed.";

View File

@@ -81,6 +81,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
std::string weightFile;
TypeId inputDataType;
TypeId outputDataType;
+ std::string save_fp16_str_ = "off";
+ bool save_fp16_ = false;
// used for quantization
std::string quantTypeStr;
schema::QuantType quantType;

View File

@@ -33,6 +33,7 @@
#include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
#include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h"
#include "tools/converter/legacy_optimizer/graph/switch_pass.h"
#include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
#include "tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.h"
@@ -186,6 +187,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass());
+ forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(ctx.save_fp16_));
status = forming_model_optimizer.Run(graph_defT_);
if (status != RET_OK) {
MS_LOG(ERROR) << "Run InferShapeOptimizer graphPasses Failed.";

View File

@@ -187,17 +187,20 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
}
if (func_graph == nullptr) {
MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return nullptr;
}
func_graph->set_attr("graph_name", MakeValue("main_graph"));
func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_MS)));
- if (Mindir2AnfAdjust(func_graph, flag) != RET_OK) {
+ STATUS status;
+ if ((status = Mindir2AnfAdjust(func_graph, flag)) != RET_OK) {
MS_LOG(ERROR) << "Mindir2AnfAdjust failed.";
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
- auto status = WeightFormatTransform(func_graph);
- if (status != RET_OK) {
+ if ((status = WeightFormatTransform(func_graph)) != RET_OK) {
MS_LOG(ERROR) << "WeightFormatTransform failed.";
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
return func_graph;

View File

@@ -7,6 +7,7 @@ file(GLOB GRAPH_PASS
${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/tensor_quant_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/infer_quant_param_pass.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/convert_fp32_to_fp16_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/switch_pass.cc

View File

@@ -0,0 +1,62 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h"
#include <queue>
#include <vector>
#include "tools/converter/converter_context.h"
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "tools/common/tensor_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "Eigen/Core"
using float16 = Eigen::half;
namespace mindspore {
namespace lite {
STATUS ConvertFP32ToFP16Pass::Run(schema::MetaGraphT *graph) {
if (!need_convert_) {
return RET_NO_CHANGE;
}
MS_ASSERT(graph != nullptr);
bool if_changed = false;
for (auto &tensor : graph->allTensors) {
if (tensor->dataType != kNumberTypeFloat32 || tensor->data.empty()) {
continue;
}
auto ele_num = lite::GetShapeSize(tensor->dims);
auto origin_data = tensor->data;
if (origin_data.size() != ele_num * sizeof(float) || origin_data.size() % 2 != 0) {
MS_LOG(ERROR) << "Tensor data length error.";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
return RET_ERROR;
}
std::vector<uint8_t> new_data(origin_data.size() / 2);
auto fp32_data = reinterpret_cast<float *>(origin_data.data());
auto fp16_data = reinterpret_cast<float16 *>(new_data.data());
for (size_t i = 0; i < ele_num; i++) {
fp16_data[i] = float16(fp32_data[i]);
}
tensor->data.swap(new_data);
tensor->dataType = kNumberTypeFloat16;
new_data.clear();
if_changed = true;
}
return if_changed ? RET_OK : RET_NO_CHANGE;
}
} // namespace lite
} // namespace mindspore
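
The heart of the pass above is the in-place rewrite of each fp32 tensor buffer as fp16, halving its size via Eigen::half. A minimal standalone sketch of just that conversion (Fp32BufferToFp16 and the sample weights are illustrative; the Eigen::half cast mirrors the one the pass uses):

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>
#include "Eigen/Core"

using float16 = Eigen::half;

// Reinterpret a byte buffer of fp32 values and rewrite it as fp16,
// halving its size, as ConvertFP32ToFP16Pass does per tensor.
std::vector<uint8_t> Fp32BufferToFp16(const std::vector<uint8_t> &fp32_bytes) {
  size_t ele_num = fp32_bytes.size() / sizeof(float);
  std::vector<uint8_t> fp16_bytes(fp32_bytes.size() / 2);
  auto fp32_data = reinterpret_cast<const float *>(fp32_bytes.data());
  auto fp16_data = reinterpret_cast<float16 *>(fp16_bytes.data());
  for (size_t i = 0; i < ele_num; i++) {
    fp16_data[i] = float16(fp32_data[i]);  // lossy round to half precision
  }
  return fp16_bytes;
}

int main() {
  std::vector<float> weights = {0.5f, -1.25f, 3.1415926f};
  std::vector<uint8_t> bytes(weights.size() * sizeof(float));
  std::memcpy(bytes.data(), weights.data(), bytes.size());
  auto fp16_bytes = Fp32BufferToFp16(bytes);
  auto fp16_view = reinterpret_cast<const float16 *>(fp16_bytes.data());
  for (size_t i = 0; i < weights.size(); i++) {
    std::cout << weights[i] << " -> " << static_cast<float>(fp16_view[i]) << std::endl;
  }
  return 0;
}

The cast is lossy (fp16 keeps roughly three decimal digits of precision), which is the size-for-precision trade-off that turning save_fp16 on opts into.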

View File

@@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_CONVERT_FP32_TO_FP16_PASS_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_CONVERT_FP32_TO_FP16_PASS_H_
#include <unordered_map>
#include "tools/converter/optimizer.h"
#include "tools/converter/converter_flags.h"
namespace mindspore {
namespace lite {
class ConvertFP32ToFP16Pass : public GraphPass {
public:
explicit ConvertFP32ToFP16Pass(bool save_fp16) : need_convert_(save_fp16) {}
~ConvertFP32ToFP16Pass() override = default;
STATUS Run(schema::MetaGraphT *graph) override;
private:
bool need_convert_ = false;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAPH_CONVERT_FP32_TO_FP16_PASS_H_

View File

@@ -97,12 +97,14 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::Flags &flag) {
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE)));
std::set<FuncGraphPtr> all_func_graphs = {};
GetAllFuncGraph(res_graph_, &all_func_graphs);
- if (CommonAnfAdjust(all_func_graphs) != RET_OK) {
+ if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
MS_LOG(ERROR) << "AdjustForAnf failed.";
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
- status = WeightFormatTransform(res_graph_);
- if (status != RET_OK) {
+ if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
MS_LOG(ERROR) << "WeightFormatTransform failed.";
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
return res_graph_;
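
The same edit recurs in the ONNX, TF, and TFLite parsers below: a bare boolean check becomes an assign-in-condition so the concrete status code is recorded in the ReturnCode singleton before bailing out. The idiom in isolation (Status, DoAdjust, and g_return_code are illustrative stand-ins, not the converter's real types):

#include <iostream>

// Illustrative stand-ins for STATUS/RET_OK and the ReturnCode singleton.
enum Status { RET_OK = 0, RET_ERROR = 1 };
Status g_return_code = RET_OK;
Status DoAdjust() { return RET_ERROR; }

int main() {
  Status status;
  // Capture the status inside the condition so the failure path can both
  // log and record the concrete error code, as the parser hunks do.
  if ((status = DoAdjust()) != RET_OK) {
    std::cerr << "DoAdjust failed." << std::endl;
    g_return_code = status;
    return status;
  }
  return RET_OK;
}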

View File

@@ -79,17 +79,19 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::Flags &flag) {
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
std::set<FuncGraphPtr> all_func_graphs = {};
GetAllFuncGraph(res_graph_, &all_func_graphs);
- if (CommonAnfAdjust(all_func_graphs) != RET_OK) {
+ if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
MS_LOG(ERROR) << "AdjustForAnf failed.";
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
- if (Onnx2AnfAdjust(all_func_graphs) != RET_OK) {
+ if ((status = Onnx2AnfAdjust(all_func_graphs)) != RET_OK) {
MS_LOG(ERROR) << "Onnx2AnfAdjust failed.";
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
- status = WeightFormatTransform(all_func_graphs);
- if (status != RET_OK) {
+ if ((status = WeightFormatTransform(all_func_graphs)) != RET_OK) {
MS_LOG(ERROR) << "WeightFormatTransform failed.";
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
return res_graph_;

View File

@@ -557,16 +557,19 @@ FuncGraphPtr TFModelParser::Parse(const converter::Flags &flag) {
std::set<FuncGraphPtr> all_func_graphs = {};
GetAllFuncGraph(res_graph_, &all_func_graphs);
- if (CommonAnfAdjust(all_func_graphs) != RET_OK) {
+ if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
MS_LOG(ERROR) << "AdjustForAnf failed.";
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
- if (TF2AnfAdjust(all_func_graphs) != RET_OK) {
+ if ((status = TF2AnfAdjust(all_func_graphs)) != RET_OK) {
MS_LOG(ERROR) << "TF2AnfAdjust failed.";
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
- status = WeightFormatTransform(res_graph_);
- if (status != RET_OK) {
+ if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
MS_LOG(ERROR) << "WeightFormatTransform failed.";
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
res_graph_->set_manager(nullptr);

View File

@@ -93,17 +93,19 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::Flags &flag) {
std::set<FuncGraphPtr> all_func_graphs = {};
GetAllFuncGraph(res_graph_, &all_func_graphs);
- if (CommonAnfAdjust(all_func_graphs) != RET_OK) {
+ if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
MS_LOG(ERROR) << "AdjustForAnf failed.";
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
- if (Tflite2AnfAdjust(all_func_graphs) != RET_OK) {
+ if ((status = Tflite2AnfAdjust(all_func_graphs)) != RET_OK) {
MS_LOG(ERROR) << "Tflite2AnfAdjust failed.";
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
- status = WeightFormatTransform(res_graph_);
- if (status != RET_OK) {
+ if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
MS_LOG(ERROR) << "WeightFormatTransform failed.";
+ ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
}
return res_graph_;
@@ -250,6 +252,7 @@ STATUS TfliteModelParser::ConvertOps() {
if (node_parser == nullptr) {
NotSupportOp::GetInstance()->InsertOp(op_type);
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
MS_LOG(ERROR) << "Can not find " << op_type << " op parser.";
continue;
}
if (status != RET_OK) {

View File

@@ -49,6 +49,20 @@ ops::PrimitiveC *TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT>
return nullptr;
}
prim->set_paddings(paddings);
+ } else if (tflite_op_type == tflite::BuiltinOperator_PADV2) {
+ prim->set_padding_mode(mindspore::PaddingMode::CONSTANT);
+ if (tflite_op->inputs.size() < 3) {
+ MS_LOG(ERROR) << "tflite padv2 requires at least 3 inputs, but got " << tflite_op->inputs.size();
+ return nullptr;
+ }
+ std::vector<float> constant_value;
+ auto ret = GetTfliteData(tflite_op->inputs.at(2), tflite_subgraph->tensors, tflite_model->buffers, constant_value);
+ if (ret != RET_OK || constant_value.size() != 1) {
+ MS_LOG(ERROR) << "Get PadV2 constant_value failed.";
+ return nullptr;
+ }
+ prim->set_constant_value(constant_value.at(0));
} else if (tflite_op_type == tflite::BuiltinOperator_MIRROR_PAD) {
const auto &tflite_attr = tflite_op->builtin_options.AsMirrorPadOptions();
if (tflite_attr == nullptr) {
@@ -75,6 +89,7 @@ ops::PrimitiveC *TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT>
}
TfliteNodeRegister g_tflitePadParser(tflite::BuiltinOperator_PAD, new TflitePadParser());
+ TfliteNodeRegister g_tflitePadV2Parser(tflite::BuiltinOperator_PADV2, new TflitePadParser());
TfliteNodeRegister g_tfliteMirorPadParser(tflite::BuiltinOperator_MIRROR_PAD, new TflitePadParser());
} // namespace lite
} // namespace mindspore
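
For context on the new parser: TFLite's PADV2 differs from PAD by taking a third input, a scalar constant that fills the padded region (PAD always pads with zero), which is why the code above requires at least three inputs and reads constant_value from inputs[2]. A minimal sketch of the operator's semantics on a 1-D input (PadV2_1D is illustrative, not converter code):

#include <algorithm>
#include <iostream>
#include <vector>

// Illustrative PADV2 semantics on a 1-D input: pad `before` values in front
// and `after` values behind, filling with constant_value instead of zero.
std::vector<float> PadV2_1D(const std::vector<float> &input, size_t before, size_t after,
                            float constant_value) {
  std::vector<float> output(input.size() + before + after, constant_value);
  std::copy(input.begin(), input.end(), output.begin() + before);
  return output;
}

int main() {
  for (float v : PadV2_1D({1.0f, 2.0f, 3.0f}, 2, 1, 9.0f)) {
    std::cout << v << " ";  // prints: 9 9 1 2 3 9
  }
  std::cout << std::endl;
  return 0;
}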