forked from mindspore-Ecosystem/mindspore
!16005 add tflite padv2 parser
Merge pull request !16005 from hangq/master
This commit is contained in:
commit
92d25d57d0
|
@ -726,6 +726,7 @@ TypeId GetParameterDtype(const ParameterPtr ¶m_node) {
|
|||
}
|
||||
|
||||
STATUS UpdateFuncGraphInputsAndOutputsDtype(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
// update graph inputs dtype
|
||||
size_t idx = 0;
|
||||
for (auto &input : func_graph->get_inputs()) {
|
||||
|
|
|
@ -37,19 +37,22 @@ FuncGraphPtr Converter::BuildFuncGraph(const converter::Flags &flag) {
|
|||
MindsporeImporter ms_import;
|
||||
func_graph = ms_import.ImportMindIR(flag);
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
model_parser_ = ModelParserRegistry::GetInstance()->GetModelParser(flag.fmkIn);
|
||||
if (model_parser_ == nullptr) {
|
||||
MS_LOG(ERROR) << "get funcGraph failed for fmk:" << flag.fmkIn;
|
||||
return nullptr;
|
||||
}
|
||||
func_graph = model_parser_->Parse(flag);
|
||||
}
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Get funcGraph failed for fmk: " << flag.fmkIn;
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NOT_SUPPORT);
|
||||
return nullptr;
|
||||
}
|
||||
if (UpdateFuncGraphInputsAndOutputsDtype(func_graph) != RET_OK) {
|
||||
MS_LOG(ERROR) << "update graph inputs and outputs dtype failed.";
|
||||
MS_LOG(ERROR) << "Update graph inputs and outputs dtype failed.";
|
||||
return nullptr;
|
||||
}
|
||||
return func_graph;
|
||||
|
|
|
@ -270,6 +270,15 @@ int Flags::Init(int argc, const char **argv) {
|
|||
}
|
||||
}
|
||||
|
||||
if (save_fp16_str_ == "on") {
|
||||
save_fp16_ = true;
|
||||
} else if (save_fp16_str_ == "off") {
|
||||
save_fp16_ = false;
|
||||
} else {
|
||||
std::cerr << "Init save_fp16 failed.";
|
||||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
|
||||
ret = InitInputOutputDataType();
|
||||
if (ret != RET_OK) {
|
||||
std::cerr << "Init input output datatype failed.";
|
||||
|
|
|
@ -81,6 +81,8 @@ class Flags : public virtual mindspore::lite::FlagParser {
|
|||
std::string weightFile;
|
||||
TypeId inputDataType;
|
||||
TypeId outputDataType;
|
||||
std::string save_fp16_str_ = "off";
|
||||
bool save_fp16_ = false;
|
||||
// used for quantization
|
||||
std::string quantTypeStr;
|
||||
schema::QuantType quantType;
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
#include "tools/converter/legacy_optimizer/graph/infer_quant_param_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/set_unused_quant_param_to_default_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/switch_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
|
||||
#include "tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.h"
|
||||
|
@ -186,6 +187,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
|||
forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
|
||||
forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
|
||||
forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass());
|
||||
forming_model_optimizer.AddPass(new (std::nothrow) ConvertFP32ToFP16Pass(ctx.save_fp16_));
|
||||
status = forming_model_optimizer.Run(graph_defT_);
|
||||
if (status != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run InferShapeOptimizer graphPasses Failed.";
|
||||
|
|
|
@ -187,17 +187,20 @@ FuncGraphPtr MindsporeImporter::ImportMindIR(const converter::Flags &flag) {
|
|||
}
|
||||
if (func_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "get funcGraph failed for fmk:MINDIR";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
return nullptr;
|
||||
}
|
||||
func_graph->set_attr("graph_name", MakeValue("main_graph"));
|
||||
func_graph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_MS)));
|
||||
if (Mindir2AnfAdjust(func_graph, flag) != RET_OK) {
|
||||
STATUS status;
|
||||
if ((status = Mindir2AnfAdjust(func_graph, flag)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Mindir2AnfAdjust failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
auto status = WeightFormatTransform(func_graph);
|
||||
if (status != RET_OK) {
|
||||
if ((status = WeightFormatTransform(func_graph)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
return func_graph;
|
||||
|
|
|
@ -7,6 +7,7 @@ file(GLOB GRAPH_PASS
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/tensor_quant_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/infer_quant_param_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/convert_fp32_to_fp16_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/switch_pass.cc
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/legacy_optimizer/graph/convert_fp32_to_fp16_pass.h"
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include "tools/converter/converter_context.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "tools/common/tensor_util.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "Eigen/Core"
|
||||
|
||||
using float16 = Eigen::half;
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
STATUS ConvertFP32ToFP16Pass::Run(schema::MetaGraphT *graph) {
|
||||
if (!need_convert_) {
|
||||
return RET_NO_CHANGE;
|
||||
}
|
||||
MS_ASSERT(graph != nullptr);
|
||||
bool if_changed = false;
|
||||
for (auto &tensor : graph->allTensors) {
|
||||
if (tensor->dataType != kNumberTypeFloat32 || tensor->data.empty()) {
|
||||
continue;
|
||||
}
|
||||
auto ele_num = lite::GetShapeSize(tensor->dims);
|
||||
auto origin_data = tensor->data;
|
||||
if (origin_data.size() != ele_num * sizeof(float) || origin_data.size() % 2 != 0) {
|
||||
MS_LOG(ERROR) << "Tensor data length error.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||
return RET_ERROR;
|
||||
}
|
||||
std::vector<uint8_t> new_data(origin_data.size() / 2);
|
||||
auto fp32_data = reinterpret_cast<float *>(origin_data.data());
|
||||
auto fp16_data = reinterpret_cast<float16 *>(new_data.data());
|
||||
for (size_t i = 0; i < ele_num; i++) {
|
||||
fp16_data[i] = float16(fp32_data[i]);
|
||||
}
|
||||
tensor->data.swap(new_data);
|
||||
tensor->dataType = kNumberTypeFloat16;
|
||||
new_data.clear();
|
||||
if_changed = true;
|
||||
}
|
||||
return if_changed ? RET_OK : RET_NO_CHANGE;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAGP_CONVERT_FP32_TO_FP16_PASS_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAGP_CONVERT_FP32_TO_FP16_PASS_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include "tools/converter/optimizer.h"
|
||||
#include "tools/converter/converter_flags.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class ConvertFP32ToFP16Pass : public GraphPass {
|
||||
public:
|
||||
explicit ConvertFP32ToFP16Pass(bool save_fp16) : need_convert_(save_fp16) {}
|
||||
|
||||
~ConvertFP32ToFP16Pass() override = default;
|
||||
|
||||
STATUS Run(schema::MetaGraphT *graph) override;
|
||||
|
||||
private:
|
||||
bool need_convert_ = false;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_GRAGP_CONVERT_FP32_TO_FP16_PASS_H_
|
|
@ -97,12 +97,14 @@ FuncGraphPtr CaffeModelParser::Parse(const converter::Flags &flag) {
|
|||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE)));
|
||||
std::set<FuncGraphPtr> all_func_graphs = {};
|
||||
GetAllFuncGraph(res_graph_, &all_func_graphs);
|
||||
if (CommonAnfAdjust(all_func_graphs) != RET_OK) {
|
||||
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "AdjustForAnf failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
status = WeightFormatTransform(res_graph_);
|
||||
if (status != RET_OK) {
|
||||
if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
return res_graph_;
|
||||
|
|
|
@ -79,17 +79,19 @@ FuncGraphPtr OnnxModelParser::Parse(const converter::Flags &flag) {
|
|||
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
|
||||
std::set<FuncGraphPtr> all_func_graphs = {};
|
||||
GetAllFuncGraph(res_graph_, &all_func_graphs);
|
||||
if (CommonAnfAdjust(all_func_graphs) != RET_OK) {
|
||||
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "AdjustForAnf failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
if (Onnx2AnfAdjust(all_func_graphs) != RET_OK) {
|
||||
if ((status = Onnx2AnfAdjust(all_func_graphs)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Onnx2AnfAdjust failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
status = WeightFormatTransform(all_func_graphs);
|
||||
if (status != RET_OK) {
|
||||
if ((status = WeightFormatTransform(all_func_graphs)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
return res_graph_;
|
||||
|
|
|
@ -557,16 +557,19 @@ FuncGraphPtr TFModelParser::Parse(const converter::Flags &flag) {
|
|||
std::set<FuncGraphPtr> all_func_graphs = {};
|
||||
GetAllFuncGraph(res_graph_, &all_func_graphs);
|
||||
|
||||
if (CommonAnfAdjust(all_func_graphs) != RET_OK) {
|
||||
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "AdjustForAnf failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
if (TF2AnfAdjust(all_func_graphs) != RET_OK) {
|
||||
if ((status = TF2AnfAdjust(all_func_graphs)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "TF2AnfAdjust failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
status = WeightFormatTransform(res_graph_);
|
||||
if (status != RET_OK) {
|
||||
if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
res_graph_->set_manager(nullptr);
|
||||
|
|
|
@ -93,17 +93,19 @@ FuncGraphPtr TfliteModelParser::Parse(const converter::Flags &flag) {
|
|||
std::set<FuncGraphPtr> all_func_graphs = {};
|
||||
GetAllFuncGraph(res_graph_, &all_func_graphs);
|
||||
|
||||
if (CommonAnfAdjust(all_func_graphs) != RET_OK) {
|
||||
if ((status = CommonAnfAdjust(all_func_graphs)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "AdjustForAnf failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
if (Tflite2AnfAdjust(all_func_graphs) != RET_OK) {
|
||||
if ((status = Tflite2AnfAdjust(all_func_graphs)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "Tflite2AnfAdjust failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
status = WeightFormatTransform(res_graph_);
|
||||
if (status != RET_OK) {
|
||||
if ((status = WeightFormatTransform(res_graph_)) != RET_OK) {
|
||||
MS_LOG(ERROR) << "WeightFormatTransform failed.";
|
||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||
return nullptr;
|
||||
}
|
||||
return res_graph_;
|
||||
|
@ -250,6 +252,7 @@ STATUS TfliteModelParser::ConvertOps() {
|
|||
if (node_parser == nullptr) {
|
||||
NotSupportOp::GetInstance()->InsertOp(op_type);
|
||||
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
|
||||
MS_LOG(ERROR) << "Can not find " << op_type << " op parser.";
|
||||
continue;
|
||||
}
|
||||
if (status != RET_OK) {
|
||||
|
|
|
@ -49,6 +49,20 @@ ops::PrimitiveC *TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT>
|
|||
return nullptr;
|
||||
}
|
||||
prim->set_paddings(paddings);
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_PADV2) {
|
||||
prim->set_padding_mode(mindspore::PaddingMode::CONSTANT);
|
||||
if (tflite_op->inputs.size() < 3) {
|
||||
MS_LOG(ERROR) << "tflite padv2 input size less than 3, which is " << tflite_op->inputs.size();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<float> constant_value;
|
||||
auto ret = GetTfliteData(tflite_op->inputs.at(2), tflite_subgraph->tensors, tflite_model->buffers, constant_value);
|
||||
if (ret != RET_OK || constant_value.size() != 1) {
|
||||
MS_LOG(ERROR) << "get Pad -> constant_value failed";
|
||||
return nullptr;
|
||||
}
|
||||
prim->set_constant_value(constant_value.at(0));
|
||||
} else if (tflite_op_type == tflite::BuiltinOperator_MIRROR_PAD) {
|
||||
const auto &tflite_attr = tflite_op->builtin_options.AsMirrorPadOptions();
|
||||
if (tflite_attr == nullptr) {
|
||||
|
@ -75,6 +89,7 @@ ops::PrimitiveC *TflitePadParser::Parse(const std::unique_ptr<tflite::OperatorT>
|
|||
}
|
||||
|
||||
TfliteNodeRegister g_tflitePadParser(tflite::BuiltinOperator_PAD, new TflitePadParser());
|
||||
TfliteNodeRegister g_tflitePadV2Parser(tflite::BuiltinOperator_PADV2, new TflitePadParser());
|
||||
TfliteNodeRegister g_tfliteMirorPadParser(tflite::BuiltinOperator_MIRROR_PAD, new TflitePadParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue