forked from mindspore-Ecosystem/mindspore
!15228 fix abstract of parameter
From: @hangangqiang Reviewed-by: @zhang_xue_tong,@ddwsky Signed-off-by: @zhang_xue_tong
This commit is contained in:
commit
1a9463c513
|
@ -38,9 +38,12 @@ int TopKInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o
|
||||||
if (!parameter->infer_flag_) {
|
if (!parameter->infer_flag_) {
|
||||||
return NNACL_INFER_INVALID;
|
return NNACL_INFER_INVALID;
|
||||||
}
|
}
|
||||||
|
const TensorC *input_k_tensor = inputs[1];
|
||||||
|
if (input_k_tensor->data_ == NULL) {
|
||||||
|
return NNACL_INFER_INVALID;
|
||||||
|
}
|
||||||
|
|
||||||
TopkParameter *param = (TopkParameter *)parameter;
|
TopkParameter *param = (TopkParameter *)parameter;
|
||||||
const TensorC *input_k_tensor = inputs[1];
|
|
||||||
param->k_ = ((int32_t *)input_k_tensor->data_)[0];
|
param->k_ = ((int32_t *)input_k_tensor->data_)[0];
|
||||||
|
|
||||||
int out_shape[MAX_SHAPE_SIZE];
|
int out_shape[MAX_SHAPE_SIZE];
|
||||||
|
|
|
@ -75,12 +75,14 @@ class LiteModel : public Model {
|
||||||
} else {
|
} else {
|
||||||
node->name_ = c_node->name()->c_str();
|
node->name_ = c_node->name()->c_str();
|
||||||
}
|
}
|
||||||
auto count = c_node->inputIndex()->size();
|
if (c_node->inputIndex() != nullptr) {
|
||||||
for (uint32_t j = 0; j < count; ++j) {
|
auto count = c_node->inputIndex()->size();
|
||||||
node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j)));
|
for (uint32_t j = 0; j < count; ++j) {
|
||||||
|
node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (c_node->outputIndex() != nullptr) {
|
if (c_node->outputIndex() != nullptr) {
|
||||||
count = c_node->outputIndex()->size();
|
auto count = c_node->outputIndex()->size();
|
||||||
for (uint32_t j = 0; j < count; ++j) {
|
for (uint32_t j = 0; j < count; ++j) {
|
||||||
node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j)));
|
node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j)));
|
||||||
}
|
}
|
||||||
|
|
|
@ -247,6 +247,7 @@ if(ENABLE_CONVERTER)
|
||||||
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
|
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc
|
||||||
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
|
${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
|
||||||
${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc
|
${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc
|
||||||
|
${LITE_DIR}/tools/optimizer/fusion/mul_add_fusion.cc
|
||||||
${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc
|
${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc
|
||||||
${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
|
${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc
|
||||||
${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc
|
${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc
|
||||||
|
|
|
@ -73,6 +73,40 @@ tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std
|
||||||
return tensor_info;
|
return tensor_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr CreateTensorAbstract(const std::vector<int64_t> &shape, TypeId data_type) {
|
||||||
|
auto tensor_info = CreateTensorInfo(nullptr, 0, shape, data_type);
|
||||||
|
if (tensor_info == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor info failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto abstract = tensor_info->ToAbstract();
|
||||||
|
if (abstract == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return abstract;
|
||||||
|
}
|
||||||
|
|
||||||
|
int SetParameterAbstractAndParam(const ParameterPtr ¶meter, const void *data, size_t data_size,
|
||||||
|
const std::vector<int64_t> &shape, TypeId data_type) {
|
||||||
|
if (parameter == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Input parameter is nullptr";
|
||||||
|
return RET_INPUT_PARAM_INVALID;
|
||||||
|
}
|
||||||
|
auto tensor_info = CreateTensorInfo(data, data_size, shape, data_type);
|
||||||
|
if (tensor_info == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor info failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
auto abstract = tensor_info->ToAbstract();
|
||||||
|
if (abstract == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
parameter->set_abstract(abstract);
|
||||||
|
return RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size) {
|
int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size) {
|
||||||
if (tensor_info == nullptr) {
|
if (tensor_info == nullptr) {
|
||||||
MS_LOG(ERROR) << "tensor info is nullptr.";
|
MS_LOG(ERROR) << "tensor info is nullptr.";
|
||||||
|
|
|
@ -46,6 +46,11 @@ std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT>
|
||||||
tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std::vector<int64_t> &shape,
|
tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std::vector<int64_t> &shape,
|
||||||
TypeId data_type);
|
TypeId data_type);
|
||||||
|
|
||||||
|
AbstractBasePtr CreateTensorAbstract(const std::vector<int64_t> &shape, TypeId data_type);
|
||||||
|
|
||||||
|
int SetParameterAbstractAndParam(const ParameterPtr ¶meter, const void *data, size_t data_size,
|
||||||
|
const std::vector<int64_t> &shape, TypeId data_type);
|
||||||
|
|
||||||
int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size);
|
int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size);
|
||||||
|
|
||||||
std::unique_ptr<schema::TensorT> CreateTensorTFromTensorInfo(const tensor::TensorPtr &tensor_info,
|
std::unique_ptr<schema::TensorT> CreateTensorTFromTensorInfo(const tensor::TensorPtr &tensor_info,
|
||||||
|
|
|
@ -54,6 +54,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
../optimizer/fusion/tf_bidirection_gru_fusion.cc
|
../optimizer/fusion/tf_bidirection_gru_fusion.cc
|
||||||
../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
|
../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc
|
||||||
../optimizer/fusion/matmul_add_fusion.cc
|
../optimizer/fusion/matmul_add_fusion.cc
|
||||||
|
../optimizer/fusion/mul_add_fusion.cc
|
||||||
../optimizer/fusion/gelu_fusion.cc
|
../optimizer/fusion/gelu_fusion.cc
|
||||||
../optimizer/fusion/tf_gelu_fusion.cc
|
../optimizer/fusion/tf_gelu_fusion.cc
|
||||||
../optimizer/fusion/onnx_gelu_fusion.cc
|
../optimizer/fusion/onnx_gelu_fusion.cc
|
||||||
|
|
|
@ -70,7 +70,7 @@ MetaGraphT *Converter::Convert(const std::unique_ptr<converter::Flags> &flag) {
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "Run anfTransform success";
|
MS_LOG(INFO) << "Run anfTransform success";
|
||||||
|
|
||||||
// protobuf -> flatbuf
|
// protobuf -> flatbuffer
|
||||||
auto meta_graph = Export(graph, false, false, flag->trainModel);
|
auto meta_graph = Export(graph, false, false, flag->trainModel);
|
||||||
if (meta_graph == nullptr) {
|
if (meta_graph == nullptr) {
|
||||||
MS_LOG(ERROR) << "Export to meta graph return nullptr";
|
MS_LOG(ERROR) << "Export to meta graph return nullptr";
|
||||||
|
|
|
@ -39,7 +39,6 @@
|
||||||
|
|
||||||
using std::string;
|
using std::string;
|
||||||
namespace mindspore::lite {
|
namespace mindspore::lite {
|
||||||
|
|
||||||
std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() {
|
std::vector<schema::CNodeT *> GraphDefTransform::GetGraphNodes() {
|
||||||
std::vector<schema::CNodeT *> old_nodes{};
|
std::vector<schema::CNodeT *> old_nodes{};
|
||||||
old_nodes.resize(graph_defT_->nodes.size());
|
old_nodes.resize(graph_defT_->nodes.size());
|
||||||
|
@ -71,54 +70,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// generate and infer quant parameters
|
// format transpose global optimize
|
||||||
{
|
|
||||||
Optimizer infer_quant_param_pass;
|
|
||||||
infer_quant_param_pass.AddPass(new (std::nothrow) TopologicalSortPass());
|
|
||||||
infer_quant_param_pass.AddPass(new (std::nothrow) InferQuantParamPass());
|
|
||||||
status = infer_quant_param_pass.Run(graph_defT_);
|
|
||||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
|
||||||
MS_LOG(ERROR) << "Run infer_quant_param_pass graphPasses Failed";
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
// format transform
|
|
||||||
// init old node indices
|
|
||||||
auto old_nodes = GetGraphNodes();
|
|
||||||
|
|
||||||
Optimizer format_trans_optimizer;
|
|
||||||
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
|
||||||
format_trans_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
|
||||||
if (ctx.fmk != converter::FmkType_TF) {
|
|
||||||
auto infer_shape_pass = new (std::nothrow) InferShapePass();
|
|
||||||
if (infer_shape_pass == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "new InferShapePass failed";
|
|
||||||
return RET_MEMORY_FAILED;
|
|
||||||
}
|
|
||||||
infer_shape_pass->set_fmk_type(ctx.fmk);
|
|
||||||
format_trans_optimizer.AddPass(infer_shape_pass);
|
|
||||||
}
|
|
||||||
status = format_trans_optimizer.Run(graph_defT_);
|
|
||||||
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
|
|
||||||
MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
{
|
|
||||||
// init old node indices
|
|
||||||
auto old_nodes = GetGraphNodes();
|
|
||||||
Optimizer format_trans_optimizer;
|
|
||||||
format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
|
||||||
format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
|
||||||
status = format_trans_optimizer.Run(graph_defT_);
|
|
||||||
if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) {
|
|
||||||
MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed";
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
// init old node indices
|
// init old node indices
|
||||||
auto old_nodes = GetGraphNodes();
|
auto old_nodes = GetGraphNodes();
|
||||||
|
@ -134,20 +86,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// postconvert pass
|
// node replace
|
||||||
{
|
if (!ctx.trainModel) {
|
||||||
// init old node indices
|
// init old node indices
|
||||||
auto old_nodes = GetGraphNodes();
|
auto old_nodes = GetGraphNodes();
|
||||||
Optimizer replace_optimizer;
|
Optimizer replace_optimizer;
|
||||||
if (!ctx.trainModel) {
|
replace_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
|
||||||
auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass();
|
replace_optimizer.AddPass(new (std::nothrow) BatchNormConvertScalePass(ctx.fmk));
|
||||||
if (batch_norm_scale_pass == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "new batch_norm_scale_pass failed.";
|
|
||||||
return RET_ERROR;
|
|
||||||
}
|
|
||||||
batch_norm_scale_pass->SetFmk(ctx.fmk);
|
|
||||||
replace_optimizer.AddPass(batch_norm_scale_pass);
|
|
||||||
}
|
|
||||||
replace_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
replace_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||||
replace_optimizer.AddPass(new SubgraphNodePass(old_nodes));
|
replace_optimizer.AddPass(new SubgraphNodePass(old_nodes));
|
||||||
status = replace_optimizer.Run(graph_defT_);
|
status = replace_optimizer.Run(graph_defT_);
|
||||||
|
@ -157,6 +102,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// node fusion
|
||||||
{
|
{
|
||||||
// init old node indices
|
// init old node indices
|
||||||
auto old_nodes = GetGraphNodes();
|
auto old_nodes = GetGraphNodes();
|
||||||
|
@ -171,19 +117,14 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// do quantization
|
// quantization
|
||||||
if (ctx.fmk != converter::FmkType_TF) {
|
if (ctx.fmk != converter::FmkType_TF) {
|
||||||
// init old node indices
|
// init old node indices
|
||||||
auto old_nodes = GetGraphNodes();
|
auto old_nodes = GetGraphNodes();
|
||||||
Optimizer tensor_quant_optimizer;
|
Optimizer tensor_quant_optimizer;
|
||||||
tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||||
auto infer_shape_pass = new (std::nothrow) InferShapePass();
|
tensor_quant_optimizer.AddPass(new (std::nothrow) InferQuantParamPass());
|
||||||
if (infer_shape_pass == nullptr) {
|
tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
|
||||||
MS_LOG(ERROR) << "new InferShapePass failed";
|
|
||||||
return RET_MEMORY_FAILED;
|
|
||||||
}
|
|
||||||
infer_shape_pass->set_fmk_type(ctx.fmk);
|
|
||||||
tensor_quant_optimizer.AddPass(infer_shape_pass);
|
|
||||||
tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass());
|
tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass());
|
||||||
tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||||
status = tensor_quant_optimizer.Run(graph_defT_);
|
status = tensor_quant_optimizer.Run(graph_defT_);
|
||||||
|
@ -193,38 +134,17 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// insert quantNode and deQuantNode
|
// quantization
|
||||||
if (ctx.fmk != converter::FmkType_TF) {
|
if (ctx.fmk != converter::FmkType_TF) {
|
||||||
// init old node indices
|
// init old node indices
|
||||||
auto old_nodes = GetGraphNodes();
|
|
||||||
Optimizer quant_node_optimizer;
|
Optimizer quant_node_optimizer;
|
||||||
quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
|
||||||
quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||||
auto infer_shape_pass = new (std::nothrow) InferShapePass();
|
auto old_nodes = GetGraphNodes();
|
||||||
if (infer_shape_pass == nullptr) {
|
quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
|
||||||
MS_LOG(ERROR) << "new InferShapePass failed";
|
quant_node_optimizer.AddPass(new (std::nothrow) DTypeTransPass(ctx.inputDataType, ctx.outputDataType));
|
||||||
return RET_MEMORY_FAILED;
|
|
||||||
}
|
|
||||||
infer_shape_pass->set_fmk_type(ctx.fmk);
|
|
||||||
quant_node_optimizer.AddPass(infer_shape_pass);
|
|
||||||
status = quant_node_optimizer.Run(graph_defT_);
|
|
||||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
|
||||||
MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
auto old_nodes2 = GetGraphNodes();
|
|
||||||
quant_node_optimizer.AddPass(new (std::nothrow) InferQuantParamPass());
|
|
||||||
auto dtype_trans_pass = new (std::nothrow) DTypeTransPass();
|
|
||||||
if (dtype_trans_pass == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "new dtype_trans_pass failed";
|
|
||||||
return RET_MEMORY_FAILED;
|
|
||||||
}
|
|
||||||
dtype_trans_pass->set_input_data_dtype(ctx.inputDataType);
|
|
||||||
dtype_trans_pass->set_output_data_dtype(ctx.outputDataType);
|
|
||||||
quant_node_optimizer.AddPass(dtype_trans_pass);
|
|
||||||
quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
|
quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
|
||||||
quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||||
quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2));
|
quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||||
status = quant_node_optimizer.Run(graph_defT_);
|
status = quant_node_optimizer.Run(graph_defT_);
|
||||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||||
MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
|
MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed";
|
||||||
|
@ -232,7 +152,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// switch pass
|
// controlflow pass
|
||||||
{
|
{
|
||||||
// init old node indices
|
// init old node indices
|
||||||
auto old_nodes = GetGraphNodes();
|
auto old_nodes = GetGraphNodes();
|
||||||
|
@ -240,6 +160,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
||||||
switch_optimizer.AddPass(new (std::nothrow) SwitchPass());
|
switch_optimizer.AddPass(new (std::nothrow) SwitchPass());
|
||||||
switch_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
switch_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
|
||||||
switch_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
switch_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||||
|
switch_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
|
||||||
status = switch_optimizer.Run(graph_defT_);
|
status = switch_optimizer.Run(graph_defT_);
|
||||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||||
MS_LOG(ERROR) << "Run switch_optimizer Failed";
|
MS_LOG(ERROR) << "Run switch_optimizer Failed";
|
||||||
|
@ -247,34 +168,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// subgraph tensor pass
|
|
||||||
{
|
|
||||||
Optimizer subgraph_tensor_optimizer;
|
|
||||||
subgraph_tensor_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass());
|
|
||||||
status = subgraph_tensor_optimizer.Run(graph_defT_);
|
|
||||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
|
||||||
MS_LOG(ERROR) << "Run subgraph tensor pass Failed";
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// tensor name
|
|
||||||
{
|
|
||||||
// init old node indices
|
|
||||||
auto old_nodes = GetGraphNodes();
|
|
||||||
Optimizer name_optimizer;
|
|
||||||
name_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
|
||||||
name_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
|
||||||
name_optimizer.AddPass(new (std::nothrow) TensorNamePass());
|
|
||||||
status = name_optimizer.Run(graph_defT_);
|
|
||||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
|
||||||
MS_LOG(ERROR) << "Run name_optimizer graphPasses Failed";
|
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
{
|
||||||
Optimizer nested_loop_optimizer;
|
Optimizer nested_loop_optimizer;
|
||||||
|
auto old_nodes = GetGraphNodes();
|
||||||
|
nested_loop_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
|
||||||
|
nested_loop_optimizer.AddPass(new (std::nothrow) TopologicalSortPass());
|
||||||
nested_loop_optimizer.AddPass(new (std::nothrow) NestedLoopExpandPass());
|
nested_loop_optimizer.AddPass(new (std::nothrow) NestedLoopExpandPass());
|
||||||
status = nested_loop_optimizer.Run(graph_defT_);
|
status = nested_loop_optimizer.Run(graph_defT_);
|
||||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
if (status != RET_OK && status != RET_NO_CHANGE) {
|
||||||
|
@ -284,30 +182,16 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
Optimizer quant_param_optimizer;
|
Optimizer forming_model_optimizer;
|
||||||
quant_param_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
|
forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk));
|
||||||
status = quant_param_optimizer.Run(graph_defT_);
|
forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass());
|
||||||
if (status != RET_OK && status != RET_NO_CHANGE) {
|
forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass());
|
||||||
MS_LOG(ERROR) << "Run quant_param_optimizer graphPasses Failed";
|
status = forming_model_optimizer.Run(graph_defT_);
|
||||||
return status;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
Optimizer infer_shape_optimizer;
|
|
||||||
auto infer_shape_pass = new (std::nothrow) InferShapePass();
|
|
||||||
if (infer_shape_pass == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "new InferShapePass failed";
|
|
||||||
return RET_MEMORY_FAILED;
|
|
||||||
}
|
|
||||||
infer_shape_pass->set_fmk_type(ctx.fmk);
|
|
||||||
infer_shape_optimizer.AddPass(infer_shape_pass);
|
|
||||||
status = infer_shape_optimizer.Run(graph_defT_);
|
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Run InferShapeOptimizer graphPasses Failed.";
|
MS_LOG(ERROR) << "Run InferShapeOptimizer graphPasses Failed.";
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
} // namespace mindspore::lite
|
}
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -36,14 +36,12 @@ struct BNWeightTensors {
|
||||||
};
|
};
|
||||||
class BatchNormConvertScalePass : public GraphPass {
|
class BatchNormConvertScalePass : public GraphPass {
|
||||||
public:
|
public:
|
||||||
BatchNormConvertScalePass() = default;
|
explicit BatchNormConvertScalePass(converter::FmkType fmk) : fmkType(fmk) {}
|
||||||
|
|
||||||
~BatchNormConvertScalePass() = default;
|
~BatchNormConvertScalePass() = default;
|
||||||
|
|
||||||
STATUS Run(MetaGraphT *graph) override;
|
STATUS Run(MetaGraphT *graph) override;
|
||||||
|
|
||||||
void SetFmk(converter::FmkType fmk) { this->fmkType = fmk; }
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
STATUS GetTransParam(MetaGraphT *graph, const std::unique_ptr<CNodeT> &bnNode);
|
STATUS GetTransParam(MetaGraphT *graph, const std::unique_ptr<CNodeT> &bnNode);
|
||||||
|
|
||||||
|
|
|
@ -276,10 +276,5 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte
|
||||||
return InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num,
|
return InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num,
|
||||||
castOpCopyer);
|
castOpCopyer);
|
||||||
}
|
}
|
||||||
|
|
||||||
void DTypeTransPass::set_input_data_dtype(TypeId data_type) { this->input_data_dtype = data_type; }
|
|
||||||
|
|
||||||
void DTypeTransPass::set_output_data_dtype(TypeId data_type) { this->output_data_dtype = data_type; }
|
|
||||||
|
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -30,16 +30,13 @@ enum DTypeTransNodeType { kInt8ToFP32, kFP32ToInt8, kUInt8ToInt8, kInt8ToUInt8 }
|
||||||
|
|
||||||
class DTypeTransPass : public GraphPass {
|
class DTypeTransPass : public GraphPass {
|
||||||
public:
|
public:
|
||||||
DTypeTransPass() : id_(0) {}
|
DTypeTransPass(TypeId model_input_data_type, TypeId model_output_data_type)
|
||||||
|
: id_(0), input_data_dtype(model_input_data_type), output_data_dtype(model_output_data_type) {}
|
||||||
|
|
||||||
~DTypeTransPass() override = default;
|
~DTypeTransPass() override = default;
|
||||||
|
|
||||||
STATUS Run(schema::MetaGraphT *graph) override;
|
STATUS Run(schema::MetaGraphT *graph) override;
|
||||||
|
|
||||||
void set_input_data_dtype(TypeId data_type);
|
|
||||||
|
|
||||||
void set_output_data_dtype(TypeId dataType);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph);
|
STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph);
|
||||||
|
|
||||||
|
|
|
@ -39,14 +39,10 @@ struct InferTensor {
|
||||||
|
|
||||||
class InferShapePass : public GraphPass {
|
class InferShapePass : public GraphPass {
|
||||||
public:
|
public:
|
||||||
InferShapePass() = default;
|
explicit InferShapePass(converter::FmkType fmk_type) : fmk_type_(fmk_type) {}
|
||||||
|
~InferShapePass() override = default;
|
||||||
~InferShapePass() = default;
|
|
||||||
|
|
||||||
STATUS Run(MetaGraphT *graph) override;
|
STATUS Run(MetaGraphT *graph) override;
|
||||||
|
|
||||||
void set_fmk_type(converter::FmkType fmk_type) { this->fmk_type_ = fmk_type; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void InitSearchTensor(MetaGraphT *graph);
|
void InitSearchTensor(MetaGraphT *graph);
|
||||||
void AddNextInferShapeNode(std::vector<uint32_t> output_tensor_node_indexes, size_t index);
|
void AddNextInferShapeNode(std::vector<uint32_t> output_tensor_node_indexes, size_t index);
|
||||||
|
|
|
@ -34,8 +34,28 @@ class ModelParser {
|
||||||
|
|
||||||
virtual ~ModelParser() = default;
|
virtual ~ModelParser() = default;
|
||||||
|
|
||||||
virtual FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
|
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type) {
|
||||||
const QuantType &quant_type) = 0;
|
auto ret = ParseToFuncGraph(model_file, weight_file, quant_type);
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Parse to func graph failed : " << ret;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
ret = PostAdjust();
|
||||||
|
if (ret != RET_OK) {
|
||||||
|
MS_LOG(ERROR) << "Adjust func graph failed : " << ret;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return this->res_graph_;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
virtual int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||||
|
const QuantType &quant_type) = 0;
|
||||||
|
|
||||||
|
virtual int PostAdjust() = 0;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
FuncGraphPtr res_graph_ = nullptr;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include "tools/common/tensor_util.h"
|
||||||
#include "tools/converter/ops/while.h"
|
#include "tools/converter/ops/while.h"
|
||||||
#include "utils/check_convert_utils.h"
|
#include "utils/check_convert_utils.h"
|
||||||
#include "abstract/primitive_infer_map.h"
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
@ -55,7 +56,9 @@ AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
||||||
AbstractBasePtrList output;
|
AbstractBasePtrList output;
|
||||||
for (int64_t i = 0; i < (int64_t)input_args.size(); i++) {
|
for (int64_t i = 0; i < (int64_t)input_args.size(); i++) {
|
||||||
auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape];
|
auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape];
|
||||||
output.push_back(std::make_shared<abstract::AbstractTensor>(input_args[i]->BuildType(), shape));
|
auto abstract_tensor = lite::CreateTensorAbstract(shape, input_args[i]->BuildType()->type_id());
|
||||||
|
MS_EXCEPTION_IF_NULL(abstract_tensor);
|
||||||
|
output.push_back(abstract_tensor);
|
||||||
}
|
}
|
||||||
return std::make_shared<abstract::AbstractTuple>(output);
|
return std::make_shared<abstract::AbstractTuple>(output);
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,34 +41,34 @@ CaffeModelParser::CaffeModelParser() = default;
|
||||||
|
|
||||||
CaffeModelParser::~CaffeModelParser() = default;
|
CaffeModelParser::~CaffeModelParser() = default;
|
||||||
|
|
||||||
FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::string &weight_file,
|
int CaffeModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||||
const QuantType &quant_type) {
|
const QuantType &quant_type) {
|
||||||
STATUS status = InitOriginModel(model_file, weight_file);
|
STATUS status = InitOriginModel(model_file, weight_file);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
func_graph_ptr_ = std::make_shared<FuncGraph>();
|
res_graph_ = std::make_shared<FuncGraph>();
|
||||||
status = ConvertGraphInputs();
|
status = ConvertGraphInputs();
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
status = ConvertLayers();
|
status = ConvertLayers();
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
status = ConvertGraphOutputs();
|
status = ConvertGraphOutputs();
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph"));
|
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
|
||||||
func_graph_ptr_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE)));
|
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_CAFFE)));
|
||||||
return func_graph_ptr_;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS CaffeModelParser::ConvertLayers() {
|
STATUS CaffeModelParser::ConvertLayers() {
|
||||||
|
@ -134,7 +134,7 @@ STATUS CaffeModelParser::ConvertLayers() {
|
||||||
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitive_c))};
|
std::vector<AnfNodePtr> op_inputs = {NewValueNode(std::shared_ptr<ops::PrimitiveC>(primitive_c))};
|
||||||
op_inputs.insert(op_inputs.end(), input_nodes.begin(), input_nodes.end());
|
op_inputs.insert(op_inputs.end(), input_nodes.begin(), input_nodes.end());
|
||||||
op_inputs.insert(op_inputs.end(), const_parameters.begin(), const_parameters.end());
|
op_inputs.insert(op_inputs.end(), const_parameters.begin(), const_parameters.end());
|
||||||
auto new_cnode = func_graph_ptr_->NewCNode(op_inputs);
|
auto new_cnode = res_graph_->NewCNode(op_inputs);
|
||||||
new_cnode->set_fullname_with_scope(layer.name());
|
new_cnode->set_fullname_with_scope(layer.name());
|
||||||
|
|
||||||
// convert outputs
|
// convert outputs
|
||||||
|
@ -194,14 +194,17 @@ STATUS CaffeModelParser::ConvertGraphInputs() {
|
||||||
for (int i = 0; i < caffe_model_.layer_size(); i++) {
|
for (int i = 0; i < caffe_model_.layer_size(); i++) {
|
||||||
auto layer = caffe_model_.layer(i);
|
auto layer = caffe_model_.layer(i);
|
||||||
if (layer.type() == "Input") {
|
if (layer.type() == "Input") {
|
||||||
auto parameter = func_graph_ptr_->add_parameter();
|
auto parameter = res_graph_->add_parameter();
|
||||||
std::vector<int64_t> shape;
|
std::vector<int64_t> shape;
|
||||||
for (int j = 0; j < layer.input_param().shape(0).dim_size(); j++) {
|
for (int j = 0; j < layer.input_param().shape(0).dim_size(); j++) {
|
||||||
shape.push_back(layer.input_param().shape(0).dim(j));
|
shape.push_back(layer.input_param().shape(0).dim(j));
|
||||||
}
|
}
|
||||||
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
|
auto abstract = CreateTensorAbstract(shape, kNumberTypeFloat32);
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
|
if (abstract == nullptr) {
|
||||||
parameter->set_abstract(abstract_tensor);
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
parameter->set_abstract(abstract);
|
||||||
parameter->set_name("graph_input-" + std::to_string(i));
|
parameter->set_name("graph_input-" + std::to_string(i));
|
||||||
nodes_.insert(std::pair(layer.top(0), parameter));
|
nodes_.insert(std::pair(layer.top(0), parameter));
|
||||||
}
|
}
|
||||||
|
@ -220,10 +223,13 @@ STATUS CaffeModelParser::ConvertGraphInputs() {
|
||||||
shape.push_back(caffe_model_.input_dim(j));
|
shape.push_back(caffe_model_.input_dim(j));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto parameter = func_graph_ptr_->add_parameter();
|
auto parameter = res_graph_->add_parameter();
|
||||||
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
|
auto abstract = CreateTensorAbstract(shape, kNumberTypeFloat32);
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape);
|
if (abstract == nullptr) {
|
||||||
parameter->set_abstract(abstract_tensor);
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
parameter->set_abstract(abstract);
|
||||||
parameter->set_name("graph_input-" + caffe_model_.input(i));
|
parameter->set_name("graph_input-" + caffe_model_.input(i));
|
||||||
nodes_.insert(std::pair(caffe_model_.input(i), parameter));
|
nodes_.insert(std::pair(caffe_model_.input(i), parameter));
|
||||||
}
|
}
|
||||||
|
@ -234,10 +240,18 @@ STATUS CaffeModelParser::ConvertGraphInputs() {
|
||||||
for (int j = 0; j < shape.dim_size(); j++) {
|
for (int j = 0; j < shape.dim_size(); j++) {
|
||||||
shape_vector.push_back(shape.dim(j));
|
shape_vector.push_back(shape.dim(j));
|
||||||
}
|
}
|
||||||
auto parameter = func_graph_ptr_->add_parameter();
|
auto parameter = res_graph_->add_parameter();
|
||||||
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
|
auto tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kNumberTypeFloat32);
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
if (tensor_info == nullptr) {
|
||||||
parameter->set_abstract(abstract_tensor);
|
MS_LOG(ERROR) << "Create tensor info failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
auto abstract = tensor_info->ToAbstract();
|
||||||
|
if (abstract == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
parameter->set_abstract(abstract);
|
||||||
parameter->set_name("graph_input-" + caffe_model_.input(i));
|
parameter->set_name("graph_input-" + caffe_model_.input(i));
|
||||||
nodes_.insert(std::pair(caffe_model_.input(i), parameter));
|
nodes_.insert(std::pair(caffe_model_.input(i), parameter));
|
||||||
}
|
}
|
||||||
|
@ -265,7 +279,7 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
|
||||||
auto cnode = nodes_.find(output_node)->second;
|
auto cnode = nodes_.find(output_node)->second;
|
||||||
make_tuple_inputs.emplace_back(cnode);
|
make_tuple_inputs.emplace_back(cnode);
|
||||||
}
|
}
|
||||||
auto make_tuple_cnode = func_graph_ptr_->NewCNode(make_tuple_inputs);
|
auto make_tuple_cnode = res_graph_->NewCNode(make_tuple_inputs);
|
||||||
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
||||||
|
|
||||||
std::vector<AnfNodePtr> op_inputs;
|
std::vector<AnfNodePtr> op_inputs;
|
||||||
|
@ -277,9 +291,9 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
|
||||||
auto value_node = NewValueNode(return_prim_ptr);
|
auto value_node = NewValueNode(return_prim_ptr);
|
||||||
op_inputs.emplace_back(value_node);
|
op_inputs.emplace_back(value_node);
|
||||||
op_inputs.emplace_back(make_tuple_cnode);
|
op_inputs.emplace_back(make_tuple_cnode);
|
||||||
auto cnode = func_graph_ptr_->NewCNode(op_inputs);
|
auto cnode = res_graph_->NewCNode(op_inputs);
|
||||||
cnode->set_fullname_with_scope("Return");
|
cnode->set_fullname_with_scope("Return");
|
||||||
func_graph_ptr_->set_return(cnode);
|
res_graph_->set_return(cnode);
|
||||||
} else {
|
} else {
|
||||||
auto returnPrim = std::make_shared<ops::Return>();
|
auto returnPrim = std::make_shared<ops::Return>();
|
||||||
if (returnPrim == nullptr) {
|
if (returnPrim == nullptr) {
|
||||||
|
@ -298,9 +312,9 @@ STATUS CaffeModelParser::ConvertGraphOutputs() {
|
||||||
return RET_NOT_FIND_OP;
|
return RET_NOT_FIND_OP;
|
||||||
}
|
}
|
||||||
opInputs.emplace_back(cnode);
|
opInputs.emplace_back(cnode);
|
||||||
auto returnCnode = func_graph_ptr_->NewCNode(opInputs);
|
auto returnCnode = res_graph_->NewCNode(opInputs);
|
||||||
returnCnode->set_fullname_with_scope("Return");
|
returnCnode->set_fullname_with_scope("Return");
|
||||||
func_graph_ptr_->set_return(returnCnode);
|
res_graph_->set_return(returnCnode);
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -333,7 +347,7 @@ STATUS CaffeModelParser::ConvertBlobs(const caffe::LayerParameter &layer, std::v
|
||||||
ConvertShape(layer.blobs(i), &shape);
|
ConvertShape(layer.blobs(i), &shape);
|
||||||
|
|
||||||
// cal Weight num
|
// cal Weight num
|
||||||
auto parameter = func_graph_ptr_->add_parameter();
|
auto parameter = res_graph_->add_parameter();
|
||||||
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
|
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
|
||||||
std::vector<int64_t> shape_vector;
|
std::vector<int64_t> shape_vector;
|
||||||
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
|
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
|
||||||
|
@ -402,17 +416,25 @@ STATUS CaffeModelParser::ConvertBottom(const caffe::LayerParameter &layer, std::
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CNodePtr &cnode) {
|
STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CNodePtr &cnode) {
|
||||||
auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32);
|
|
||||||
std::vector<int64_t> shape_vector;
|
|
||||||
if (layer.top_size() == 1) {
|
if (layer.top_size() == 1) {
|
||||||
cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
|
auto abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
|
||||||
|
if (abstract == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
cnode->set_abstract(abstract);
|
||||||
nodes_[layer.top(0)] = cnode;
|
nodes_[layer.top(0)] = cnode;
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtrList abstract_list;
|
AbstractBasePtrList abstract_list;
|
||||||
for (int i = 0; i < layer.top_size(); i++) {
|
for (int i = 0; i < layer.top_size(); i++) {
|
||||||
abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
|
auto abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
|
||||||
|
if (abstract == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
abstract_list.emplace_back(abstract);
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new TupleGetItem failed";
|
MS_LOG(ERROR) << "new TupleGetItem failed";
|
||||||
|
@ -421,7 +443,7 @@ STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CN
|
||||||
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
|
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
|
||||||
auto get_item_value = NewValueNode(MakeValue<int>(i));
|
auto get_item_value = NewValueNode(MakeValue<int>(i));
|
||||||
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, cnode, get_item_value};
|
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, cnode, get_item_value};
|
||||||
CNodePtr get_item_cnode = func_graph_ptr_->NewCNode(inputs);
|
CNodePtr get_item_cnode = res_graph_->NewCNode(inputs);
|
||||||
get_item_cnode->set_fullname_with_scope(layer.top(i));
|
get_item_cnode->set_fullname_with_scope(layer.top(i));
|
||||||
nodes_[layer.top(i)] = get_item_cnode;
|
nodes_[layer.top(i)] = get_item_cnode;
|
||||||
}
|
}
|
||||||
|
@ -446,4 +468,6 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name)
|
||||||
}
|
}
|
||||||
return layer.name();
|
return layer.name();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int CaffeModelParser::PostAdjust() { return RET_OK; }
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -32,8 +32,10 @@ class CaffeModelParser : public ModelParser {
|
||||||
|
|
||||||
~CaffeModelParser() override;
|
~CaffeModelParser() override;
|
||||||
|
|
||||||
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
|
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||||
const QuantType &quant_type) override;
|
const QuantType &quant_type) override;
|
||||||
|
|
||||||
|
int PostAdjust() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file);
|
STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file);
|
||||||
|
@ -59,7 +61,6 @@ class CaffeModelParser : public ModelParser {
|
||||||
caffe::NetParameter caffe_weight_;
|
caffe::NetParameter caffe_weight_;
|
||||||
std::unordered_map<std::string, caffe::LayerParameter> caffe_layers_;
|
std::unordered_map<std::string, caffe::LayerParameter> caffe_layers_;
|
||||||
std::unordered_map<std::string, AnfNodePtr> nodes_;
|
std::unordered_map<std::string, AnfNodePtr> nodes_;
|
||||||
FuncGraphPtr func_graph_ptr_;
|
|
||||||
};
|
};
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
||||||
|
|
|
@ -45,31 +45,31 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
|
||||||
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
|
{onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
|
||||||
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
|
{onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
|
||||||
|
|
||||||
FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file,
|
int OnnxModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||||
const QuantType &quant_type) {
|
const QuantType &quant_type) {
|
||||||
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
|
NotSupportOp::GetInstance()->set_fmk_type("ONNX");
|
||||||
anf_root_graph_ = std::make_shared<FuncGraph>();
|
res_graph_ = std::make_shared<FuncGraph>();
|
||||||
auto status = InitOriginModel(model_file);
|
auto status = InitOriginModel(model_file);
|
||||||
if (RET_OK != status) {
|
if (RET_OK != status) {
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
MS_LOG(ERROR) << "init origin model failed.";
|
MS_LOG(ERROR) << "init origin model failed.";
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
status = ConvertOnnxGraph(onnx_root_graph_, anf_root_graph_, &anf_nodes_map_, {}, "root_node");
|
status = ConvertOnnxGraph(onnx_root_graph_, res_graph_, &anf_nodes_map_, {}, "root_node");
|
||||||
if (RET_OK != status) {
|
if (RET_OK != status) {
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
MS_LOG(ERROR) << "convert onnx graph failed.";
|
MS_LOG(ERROR) << "convert onnx graph failed.";
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
static auto root_func_manager = Manage(anf_root_graph_);
|
static auto root_func_manager = Manage(res_graph_);
|
||||||
for (auto &subgraph : all_subgraphs_) {
|
for (auto &subgraph : all_subgraphs_) {
|
||||||
subgraph->set_manager(root_func_manager);
|
subgraph->set_manager(root_func_manager);
|
||||||
subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
|
subgraph->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
|
||||||
}
|
}
|
||||||
anf_root_graph_->set_attr("graph_name", MakeValue("main_graph"));
|
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
|
||||||
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
|
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
|
||||||
return anf_root_graph_;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
|
STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
|
||||||
|
@ -88,9 +88,9 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) {
|
||||||
OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version());
|
OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version());
|
||||||
onnx_root_graph_ = onnx_model_.graph();
|
onnx_root_graph_ = onnx_model_.graph();
|
||||||
if (OnnxNodeParser::opset_version() > 15) {
|
if (OnnxNodeParser::opset_version() > 15) {
|
||||||
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
|
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX)));
|
||||||
} else {
|
} else {
|
||||||
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX_LOW_VERSION)));
|
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_ONNX_LOW_VERSION)));
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -170,13 +170,16 @@ STATUS OnnxModelParser::ConvertGraphInputs(const onnx::GraphProto &onnx_graph, c
|
||||||
<< static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type());
|
<< static_cast<onnx::TensorProto_DataType>(input_value.type().tensor_type().elem_type());
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto type_ptr = TypeIdToType(data_type);
|
|
||||||
std::vector<int64_t> shape_vector;
|
std::vector<int64_t> shape_vector;
|
||||||
auto onnx_shape = input_value.type().tensor_type().shape().dim();
|
auto onnx_shape = input_value.type().tensor_type().shape().dim();
|
||||||
std::transform(onnx_shape.begin(), onnx_shape.end(), std::back_inserter(shape_vector),
|
std::transform(onnx_shape.begin(), onnx_shape.end(), std::back_inserter(shape_vector),
|
||||||
[](const onnx::TensorShapeProto_Dimension &val) { return static_cast<int64_t>(val.dim_value()); });
|
[](const onnx::TensorShapeProto_Dimension &val) { return static_cast<int64_t>(val.dim_value()); });
|
||||||
std::replace(shape_vector.begin(), shape_vector.end(), 0, -1);
|
std::replace(shape_vector.begin(), shape_vector.end(), 0, -1);
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
|
||||||
|
if (abstract_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
parameter->set_abstract(abstract_tensor);
|
parameter->set_abstract(abstract_tensor);
|
||||||
parameter->set_name(input_value.name());
|
parameter->set_name(input_value.name());
|
||||||
anf_nodes_map->emplace(input_value.name(), parameter);
|
anf_nodes_map->emplace(input_value.name(), parameter);
|
||||||
|
@ -490,17 +493,23 @@ STATUS OnnxModelParser::BuildOpOutputs(const onnx::NodeProto &onnx_node, const F
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
if (onnx_node.output_size() == 1) {
|
if (onnx_node.output_size() == 1) {
|
||||||
auto type_ptr = TypeIdToType(kNumberTypeFloat32);
|
auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
|
||||||
std::vector<int64_t> shape_vector;
|
if (abstract_tensor == nullptr) {
|
||||||
cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
cnode->set_abstract(abstract_tensor);
|
||||||
anf_nodes_map->emplace(onnx_node.output(0), cnode);
|
anf_nodes_map->emplace(onnx_node.output(0), cnode);
|
||||||
} else {
|
} else {
|
||||||
AbstractBasePtrList abstract_list;
|
AbstractBasePtrList abstract_list;
|
||||||
int op_idx = 0;
|
int op_idx = 0;
|
||||||
for (const auto &output_name : onnx_node.output()) {
|
for (const auto &output_name : onnx_node.output()) {
|
||||||
std::vector<int64_t> shape_vector;
|
auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
|
||||||
auto type_ptr = TypeIdToType(kNumberTypeFloat32);
|
if (abstract_tensor == nullptr) {
|
||||||
abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
abstract_list.emplace_back(abstract_tensor);
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new TupleGetItem failed";
|
MS_LOG(ERROR) << "new TupleGetItem failed";
|
||||||
|
@ -687,7 +696,11 @@ ParameterPtr CreateConstParamter(const FuncGraphPtr &anf_graph, int val) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto const_node = anf_graph->add_parameter();
|
auto const_node = anf_graph->add_parameter();
|
||||||
auto const_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>());
|
auto const_abstract = CreateTensorAbstract({}, kNumberTypeInt32);
|
||||||
|
if (const_abstract == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
const_node->set_abstract(const_abstract);
|
const_node->set_abstract(const_abstract);
|
||||||
int *tensor_data = new (std::nothrow) int[1];
|
int *tensor_data = new (std::nothrow) int[1];
|
||||||
if (tensor_data == nullptr) {
|
if (tensor_data == nullptr) {
|
||||||
|
@ -834,9 +847,16 @@ STATUS OnnxModelParser::AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::v
|
||||||
for (int i = 0; i < act_output_num; i++) {
|
for (int i = 0; i < act_output_num; i++) {
|
||||||
// tensor_array need as root while input
|
// tensor_array need as root while input
|
||||||
auto while_tensor_array_input = anf_root_graph->add_parameter();
|
auto while_tensor_array_input = anf_root_graph->add_parameter();
|
||||||
std::vector<int64_t> shape_vector;
|
auto tensor_info = CreateTensorInfo(nullptr, 0, {}, kObjectTypeTensorType);
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kTensorType, shape_vector);
|
if (tensor_info == nullptr) {
|
||||||
auto tensor_info = std::make_shared<tensor::Tensor>(kObjectTypeTensorType, shape_vector);
|
MS_LOG(ERROR) << "Create tensor info failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
auto abstract_tensor = tensor_info->ToAbstract();
|
||||||
|
if (abstract_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
while_tensor_array_input->set_abstract(abstract_tensor);
|
while_tensor_array_input->set_abstract(abstract_tensor);
|
||||||
while_tensor_array_input->set_default_param(tensor_info);
|
while_tensor_array_input->set_default_param(tensor_info);
|
||||||
while_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray");
|
while_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray");
|
||||||
|
@ -975,7 +995,11 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf
|
||||||
auto input_paramter = cond_graph->add_parameter();
|
auto input_paramter = cond_graph->add_parameter();
|
||||||
input_paramter->set_name(cond_graph_name + "_input_" + std::to_string(i) + "_parameter");
|
input_paramter->set_name(cond_graph_name + "_input_" + std::to_string(i) + "_parameter");
|
||||||
auto root_while_inputs = root_while_node->cast<CNodePtr>()->inputs();
|
auto root_while_inputs = root_while_node->cast<CNodePtr>()->inputs();
|
||||||
auto input_abstract = std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>());
|
auto input_abstract = CreateTensorAbstract({}, kNumberTypeInt32);
|
||||||
|
if (input_abstract == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
input_paramter->set_abstract(input_abstract);
|
input_paramter->set_abstract(input_abstract);
|
||||||
if (i == 0) {
|
if (i == 0) {
|
||||||
auto zero_parameter = CreateConstParamter(cond_graph, 0);
|
auto zero_parameter = CreateConstParamter(cond_graph, 0);
|
||||||
|
@ -987,7 +1011,11 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf
|
||||||
MS_LOG(ERROR) << "new cnode error";
|
MS_LOG(ERROR) << "new cnode error";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto less_abstract = std::make_shared<abstract::AbstractTensor>(kBool, std::vector<int64_t>());
|
auto less_abstract = CreateTensorAbstract({}, kNumberTypeBool);
|
||||||
|
if (less_abstract == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
less_cnode->set_abstract(less_abstract);
|
less_cnode->set_abstract(less_abstract);
|
||||||
less_cnode->set_fullname_with_scope(cond_graph_name + "_less_cnode");
|
less_cnode->set_fullname_with_scope(cond_graph_name + "_less_cnode");
|
||||||
}
|
}
|
||||||
|
@ -1020,12 +1048,11 @@ STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const
|
||||||
MS_LOG(ERROR) << "quant param type don't support.";
|
MS_LOG(ERROR) << "quant param type don't support.";
|
||||||
return RET_NOT_SUPPORT;
|
return RET_NOT_SUPPORT;
|
||||||
}
|
}
|
||||||
std::vector<int64_t> shape_vector;
|
auto parameter_node = res_graph_->add_parameter();
|
||||||
auto parameter_node = anf_root_graph_->add_parameter();
|
auto abstract_tensor = CreateTensorAbstract({}, type);
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), shape_vector);
|
|
||||||
if (abstract_tensor == nullptr) {
|
if (abstract_tensor == nullptr) {
|
||||||
MS_LOG(ERROR) << "new abstract_tensor failed";
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
return RET_MEMORY_FAILED;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
parameter_node->set_abstract(abstract_tensor);
|
parameter_node->set_abstract(abstract_tensor);
|
||||||
parameter_node->set_name(name);
|
parameter_node->set_name(name);
|
||||||
|
@ -1051,9 +1078,12 @@ STATUS OnnxModelParser::BuildParameterNode(const ParameterPtr ¶meter_node, c
|
||||||
MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(tensor.data_type());
|
MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(tensor.data_type());
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto type_ptr = TypeIdToType(data_type);
|
|
||||||
std::vector<int64_t> shape_vector(tensor.dims().begin(), tensor.dims().end());
|
std::vector<int64_t> shape_vector(tensor.dims().begin(), tensor.dims().end());
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type);
|
||||||
|
if (abstract_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
parameter_node->set_abstract(abstract_tensor);
|
parameter_node->set_abstract(abstract_tensor);
|
||||||
parameter_node->set_name(tensor.name());
|
parameter_node->set_name(tensor.name());
|
||||||
|
|
||||||
|
@ -1142,5 +1172,7 @@ TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type
|
||||||
}
|
}
|
||||||
return iter->second;
|
return iter->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int OnnxModelParser::PostAdjust() { return 0; }
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -40,14 +40,17 @@ class OnnxModelParser : public ModelParser {
|
||||||
|
|
||||||
~OnnxModelParser() override = default;
|
~OnnxModelParser() override = default;
|
||||||
|
|
||||||
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
|
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||||
const QuantType &quant_type) override;
|
const QuantType &quant_type) override;
|
||||||
|
|
||||||
|
int PostAdjust() override;
|
||||||
|
|
||||||
static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
|
static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type);
|
||||||
static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor,
|
static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor,
|
||||||
const tensor::TensorPtr ¶m_value_lite);
|
const tensor::TensorPtr ¶m_value_lite);
|
||||||
|
STATUS InitOriginModel(const std::string &model_file);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
STATUS InitOriginModel(const std::string &model_file);
|
|
||||||
STATUS ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
|
STATUS ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr,
|
||||||
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs,
|
std::unordered_map<std::string, AnfNodePtr> *anf_nodes_map, std::vector<AnfNodePtr> *graph_inputs,
|
||||||
const std::string &root_node_name);
|
const std::string &root_node_name);
|
||||||
|
@ -94,7 +97,6 @@ class OnnxModelParser : public ModelParser {
|
||||||
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_;
|
std::unordered_map<std::string, AnfNodePtr> anf_nodes_map_;
|
||||||
std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_;
|
std::unordered_map<std::string, std::unordered_map<std::string, AnfNodePtr> *> control_nodes_map_;
|
||||||
std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node
|
std::unordered_map<std::string, std::string> child_root_map_; // for nest control flow node
|
||||||
FuncGraphPtr anf_root_graph_ = nullptr;
|
|
||||||
};
|
};
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -418,18 +418,17 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
|
||||||
type = TensorFlowUtils::GetTFDataType(attr_value.type());
|
type = TensorFlowUtils::GetTFDataType(attr_value.type());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> shape;
|
std::vector<int64_t> shape;
|
||||||
if (TensorFlowUtils::FindAttrValue(node, "shape", &attr_value)) {
|
if (TensorFlowUtils::FindAttrValue(node, "shape", &attr_value)) {
|
||||||
auto &shape_attr = attr_value.shape();
|
auto &shape_attr = attr_value.shape();
|
||||||
for (int i = 0; i < shape_attr.dim_size(); ++i) {
|
for (int i = 0; i < shape_attr.dim_size(); ++i) {
|
||||||
shape.push_back(shape_attr.dim(i).size());
|
shape.push_back(shape_attr.dim(i).size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
|
|
||||||
|
|
||||||
if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) {
|
if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) {
|
||||||
MS_LOG(INFO) << "Found value attr, means it has default value";
|
MS_LOG(INFO) << "Found value attr, means it has default value";
|
||||||
auto status = ConvertConstTensor(node, attr_value, type, parameter, &shape_vector);
|
auto status = ConvertConstTensor(node, attr_value, type, parameter, &shape);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "convert const tensor failed.";
|
MS_LOG(ERROR) << "convert const tensor failed.";
|
||||||
return status;
|
return status;
|
||||||
|
@ -438,10 +437,10 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa
|
||||||
graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names
|
graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names
|
||||||
}
|
}
|
||||||
|
|
||||||
auto type_ptr = TypeIdToType(type == kNumberTypeInt64 ? kNumberTypeInt32 : type);
|
type = (type == kNumberTypeInt64) ? kNumberTypeInt32 : type;
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
auto abstract_tensor = CreateTensorAbstract(shape, type);
|
||||||
if (abstract_tensor == nullptr) {
|
if (abstract_tensor == nullptr) {
|
||||||
MS_LOG(ERROR) << "abstract_tensor is nullptr";
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
parameter->set_name(node.name());
|
parameter->set_name(node.name());
|
||||||
|
@ -474,51 +473,51 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts(
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
FuncGraphPtr paserTfFuction() { return nullptr; }
|
|
||||||
FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
|
int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile,
|
||||||
const QuantType &quantType) {
|
const QuantType &quantType) {
|
||||||
NotSupportOp::GetInstance()->set_fmk_type("TF");
|
NotSupportOp::GetInstance()->set_fmk_type("TF");
|
||||||
auto status = ValidateFileStr(modelFile, ".pb");
|
auto status = ValidateFileStr(modelFile, ".pb");
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb";
|
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
tf_root_graph_ = std::make_unique<tensorflow::GraphDef>();
|
tf_root_graph_ = std::make_unique<tensorflow::GraphDef>();
|
||||||
if (tf_root_graph_ == nullptr) {
|
if (tf_root_graph_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "tf_root_graph_ is nullptr";
|
MS_LOG(ERROR) << "tf_root_graph_ is nullptr";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get());
|
status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get());
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Open modelFile for TF converter failed!";
|
MS_LOG(ERROR) << "Open modelFile for TF converter failed!";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
anf_root_graph_ = std::make_shared<FuncGraph>();
|
res_graph_ = std::make_shared<FuncGraph>();
|
||||||
if (anf_root_graph_ == nullptr) {
|
if (res_graph_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "funGraphPtr is nullptr";
|
MS_LOG(ERROR) << "funGraphPtr is nullptr";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
anf_root_graph_->set_attr("graph_name", MakeValue("main_graph"));
|
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
|
||||||
anf_root_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));
|
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TF)));
|
||||||
|
|
||||||
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
|
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
|
||||||
auto &node_def = tf_root_graph_->node(i);
|
auto &node_def = tf_root_graph_->node(i);
|
||||||
tf_root_graph_nodes_[node_def.name()] = &node_def;
|
tf_root_graph_nodes_[node_def.name()] = &node_def;
|
||||||
}
|
}
|
||||||
|
|
||||||
status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_);
|
status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, res_graph_, &anf_root_node_map_);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
bool success_flag = true;
|
bool success_flag = true;
|
||||||
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
|
for (int i = 0; i < tf_root_graph_->node_size(); i++) {
|
||||||
auto &node_def = tf_root_graph_->node(i);
|
auto &node_def = tf_root_graph_->node(i);
|
||||||
status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_);
|
status = ConvertOps(node_def, tf_root_graph_nodes_, res_graph_, &anf_root_node_map_);
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
success_flag = false;
|
success_flag = false;
|
||||||
|
@ -526,7 +525,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
|
||||||
}
|
}
|
||||||
if (!success_flag) {
|
if (!success_flag) {
|
||||||
MS_LOG(ERROR) << "Convert ops failed.";
|
MS_LOG(ERROR) << "Convert ops failed.";
|
||||||
return nullptr;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!nodes_with_null_input_.empty()) {
|
if (!nodes_with_null_input_.empty()) {
|
||||||
|
@ -534,7 +533,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Connect null inputs failed.";
|
MS_LOG(ERROR) << "Connect null inputs failed.";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -542,17 +541,17 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Convert graph outputs failed.";
|
MS_LOG(ERROR) << "Convert graph outputs failed.";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
status = ConvertSubgraph();
|
status = ConvertSubgraph();
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Convert subgraph failed.";
|
MS_LOG(ERROR) << "Convert subgraph failed.";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
return anf_root_graph_;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS TFModelParser::ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
|
STATUS TFModelParser::ConvertSubgraphInputs(std::map<std::string, const tensorflow::NodeDef *> *tf_sub_node_map,
|
||||||
|
@ -746,7 +745,7 @@ STATUS TFModelParser::ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGr
|
||||||
MS_LOG(ERROR) << "while cond body size error";
|
MS_LOG(ERROR) << "while cond body size error";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
static auto root_func_manager = Manage(anf_root_graph_);
|
static auto root_func_manager = Manage(res_graph_);
|
||||||
|
|
||||||
for (auto &kv : first_func_map) {
|
for (auto &kv : first_func_map) {
|
||||||
auto control_flow_node = kv.first;
|
auto control_flow_node = kv.first;
|
||||||
|
@ -758,7 +757,7 @@ STATUS TFModelParser::ControlFlowNodePostProcess(const std::map<CNodePtr, FuncGr
|
||||||
auto second_value_node = NewValueNode(second_sub_graph);
|
auto second_value_node = NewValueNode(second_sub_graph);
|
||||||
auto inputs = control_flow_node->inputs();
|
auto inputs = control_flow_node->inputs();
|
||||||
inputs.insert(inputs.begin() + 1, {first_value_node, second_value_node});
|
inputs.insert(inputs.begin() + 1, {first_value_node, second_value_node});
|
||||||
auto new_node = anf_root_graph_->NewCNode(inputs); // must create new node, otherwise node_users won't update
|
auto new_node = res_graph_->NewCNode(inputs); // must create new node, otherwise node_users won't update
|
||||||
if (new_node == nullptr) {
|
if (new_node == nullptr) {
|
||||||
MS_LOG(ERROR) << "new node failed";
|
MS_LOG(ERROR) << "new node failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
|
@ -812,43 +811,46 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C
|
||||||
if (output_size == 0) {
|
if (output_size == 0) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
} else if (output_size == 1) {
|
} else if (output_size == 1) {
|
||||||
auto type = kFloat32;
|
auto type = kNumberTypeFloat32;
|
||||||
std::vector<int64_t> shape_vector;
|
|
||||||
if (IsTensorListOp(anf_node)) {
|
if (IsTensorListOp(anf_node)) {
|
||||||
type = TypeIdToType(kObjectTypeTensorType);
|
type = kObjectTypeTensorType;
|
||||||
}
|
}
|
||||||
auto abstract = std::make_shared<abstract::AbstractTensor>(type, shape_vector);
|
auto abstract_tensor = CreateTensorAbstract({}, type);
|
||||||
if (abstract == nullptr) {
|
if (abstract_tensor == nullptr) {
|
||||||
MS_LOG(ERROR) << "create AbstractTensor failed";
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
anf_node->set_abstract(abstract);
|
anf_node->set_abstract(abstract_tensor);
|
||||||
anf_node_map->insert(std::pair(op.name(), anf_node));
|
anf_node_map->insert(std::pair(op.name(), anf_node));
|
||||||
} else {
|
} else {
|
||||||
AbstractBasePtrList abstractList;
|
AbstractBasePtrList abstract_list;
|
||||||
for (int output_idx = 0; output_idx < output_size; output_idx++) {
|
for (int output_idx = 0; output_idx < output_size; output_idx++) {
|
||||||
std::vector<int64_t> shape_vector;
|
auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32);
|
||||||
abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector));
|
if (abstract_tensor == nullptr) {
|
||||||
auto tupleGetItemPrimPtr = std::make_shared<ops::TupleGetItem>();
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
if (tupleGetItemPrimPtr == nullptr) {
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
abstract_list.emplace_back(abstract_tensor);
|
||||||
|
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
||||||
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new TupleGetItem failed";
|
MS_LOG(ERROR) << "new TupleGetItem failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr);
|
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
|
||||||
auto getItemValue = NewValueNode(MakeValue<int>(output_idx));
|
auto get_item_value = NewValueNode(MakeValue<int>(output_idx));
|
||||||
std::vector<AnfNodePtr> inputs{tupleGetItemPrim, anf_node, getItemValue};
|
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, anf_node, get_item_value};
|
||||||
CNodePtr getItemCNode = anf_graph->NewCNode(inputs);
|
CNodePtr get_item_cnode = anf_graph->NewCNode(inputs);
|
||||||
std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
|
std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
|
||||||
auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
|
auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32);
|
||||||
if (abstract == nullptr) {
|
if (get_item_abstract == nullptr) {
|
||||||
MS_LOG(ERROR) << "create AbstractTensor failed";
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
getItemCNode->set_abstract(abstract);
|
get_item_cnode->set_abstract(get_item_abstract);
|
||||||
getItemCNode->set_fullname_with_scope(output_item_name);
|
get_item_cnode->set_fullname_with_scope(output_item_name);
|
||||||
anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode));
|
anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), get_item_cnode));
|
||||||
}
|
}
|
||||||
anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstractList));
|
anf_node->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -1022,7 +1024,7 @@ STATUS TFModelParser::ConvertRootGraphOutputs() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_);
|
auto status = MakeAnfGraphOutputs(&output_nodes, res_graph_);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "make anf graph outputs node error";
|
MS_LOG(ERROR) << "make anf graph outputs node error";
|
||||||
return status;
|
return status;
|
||||||
|
@ -1070,5 +1072,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(std::vector<AnfNodePtr> *output_nodes,
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int TFModelParser::PostAdjust() { return 0; }
|
||||||
} // namespace lite
|
} // namespace lite
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -36,9 +36,11 @@ namespace lite {
|
||||||
class TFModelParser : public ModelParser {
|
class TFModelParser : public ModelParser {
|
||||||
public:
|
public:
|
||||||
TFModelParser() = default;
|
TFModelParser() = default;
|
||||||
~TFModelParser() = default;
|
~TFModelParser() override = default;
|
||||||
|
|
||||||
FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType);
|
int ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType);
|
||||||
|
|
||||||
|
int PostAdjust() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info);
|
static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info);
|
||||||
|
@ -87,7 +89,6 @@ class TFModelParser : public ModelParser {
|
||||||
|
|
||||||
STATUS ConnectNullInput();
|
STATUS ConnectNullInput();
|
||||||
|
|
||||||
FuncGraphPtr anf_root_graph_;
|
|
||||||
std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def
|
std::unique_ptr<tensorflow::GraphDef> tf_root_graph_; // tf root graph def
|
||||||
std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map
|
std::map<std::string, const tensorflow::NodeDef *> tf_root_graph_nodes_; // tf root graph node map
|
||||||
std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_;
|
std::unordered_map<std::string, AnfNodePtr> anf_root_node_map_;
|
||||||
|
|
|
@ -43,46 +43,46 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *m
|
||||||
return tflite::UnPackModel(tflite_model_buf_);
|
return tflite::UnPackModel(tflite_model_buf_);
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::string &weight_file,
|
int TfliteModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||||
const QuantType &quant_type) {
|
const QuantType &quant_type) {
|
||||||
// load graph
|
// load graph
|
||||||
tflite_model_ = ReadTfliteModel(model_file.c_str());
|
tflite_model_ = ReadTfliteModel(model_file.c_str());
|
||||||
if (tflite_model_ == nullptr) {
|
if (tflite_model_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "read tflite model failed";
|
MS_LOG(ERROR) << "read tflite model failed";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||||
return nullptr;
|
return RET_GRAPH_FILE_ERR;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tflite_model_->subgraphs.size() != 1) {
|
if (tflite_model_->subgraphs.size() != 1) {
|
||||||
MS_LOG(ERROR) << "read tflite model subgraphs failed";
|
MS_LOG(ERROR) << "read tflite model subgraphs failed";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR);
|
||||||
return nullptr;
|
return RET_GRAPH_FILE_ERR;
|
||||||
}
|
}
|
||||||
func_graph_ = std::make_shared<FuncGraph>();
|
res_graph_ = std::make_shared<FuncGraph>();
|
||||||
func_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TFLITE)));
|
res_graph_->set_attr("fmk", MakeValue(static_cast<int>(converter::FmkType_TFLITE)));
|
||||||
|
|
||||||
auto status = ConvertGraphInputs();
|
auto status = ConvertGraphInputs();
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Convert graph inputs failed.";
|
MS_LOG(ERROR) << "Convert graph inputs failed.";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
status = ConvertOps();
|
status = ConvertOps();
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Convert ops failed.";
|
MS_LOG(ERROR) << "Convert ops failed.";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
status = ConvertGraphOutputs();
|
status = ConvertGraphOutputs();
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Convert graph outputs failed.";
|
MS_LOG(ERROR) << "Convert graph outputs failed.";
|
||||||
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
|
||||||
return nullptr;
|
return status;
|
||||||
}
|
}
|
||||||
func_graph_->set_attr("graph_name", MakeValue("main_graph"));
|
res_graph_->set_attr("graph_name", MakeValue("main_graph"));
|
||||||
return func_graph_;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type, const std::string &op_name) {
|
std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type, const std::string &op_name) {
|
||||||
|
@ -158,7 +158,7 @@ STATUS TfliteModelParser::ConvertOps() {
|
||||||
} else {
|
} else {
|
||||||
tensor_name = GetTensorName(i, tflite_op_type, op_name);
|
tensor_name = GetTensorName(i, tflite_op_type, op_name);
|
||||||
}
|
}
|
||||||
auto parameter = func_graph_->add_parameter();
|
auto parameter = res_graph_->add_parameter();
|
||||||
status = ConvertConstTensor(input_tensor.get(), parameter, tensor_name);
|
status = ConvertConstTensor(input_tensor.get(), parameter, tensor_name);
|
||||||
if (status != RET_OK) {
|
if (status != RET_OK) {
|
||||||
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
|
MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed.";
|
||||||
|
@ -168,7 +168,7 @@ STATUS TfliteModelParser::ConvertOps() {
|
||||||
op_inputs.emplace_back(parameter);
|
op_inputs.emplace_back(parameter);
|
||||||
nodes_.insert(std::pair(input_idx, parameter));
|
nodes_.insert(std::pair(input_idx, parameter));
|
||||||
}
|
}
|
||||||
auto new_cnode = func_graph_->NewCNode(op_inputs);
|
auto new_cnode = res_graph_->NewCNode(op_inputs);
|
||||||
new_cnode->set_fullname_with_scope(op_name);
|
new_cnode->set_fullname_with_scope(op_name);
|
||||||
|
|
||||||
// parse outputs
|
// parse outputs
|
||||||
|
@ -284,13 +284,16 @@ STATUS TfliteModelParser::ConvertGraphInputs() {
|
||||||
if (tflite_graph_input < 0) {
|
if (tflite_graph_input < 0) {
|
||||||
tflite_graph_input = tflite_graph_input + tflite_subgraph->tensors.size();
|
tflite_graph_input = tflite_graph_input + tflite_subgraph->tensors.size();
|
||||||
}
|
}
|
||||||
auto parameter = func_graph_->add_parameter();
|
auto parameter = res_graph_->add_parameter();
|
||||||
const auto &tensor = tflite_subgraph->tensors.at(tflite_graph_input);
|
const auto &tensor = tflite_subgraph->tensors.at(tflite_graph_input);
|
||||||
std::vector<int64_t> shape_vector;
|
std::vector<int64_t> shape_vector;
|
||||||
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
|
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
|
||||||
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
||||||
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
|
auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type));
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
if (abstract_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
parameter->set_abstract(abstract_tensor);
|
parameter->set_abstract(abstract_tensor);
|
||||||
parameter->set_name("graph_input-" + std::to_string(tflite_graph_input));
|
parameter->set_name("graph_input-" + std::to_string(tflite_graph_input));
|
||||||
nodes_.insert(std::pair(tflite_graph_input, parameter));
|
nodes_.insert(std::pair(tflite_graph_input, parameter));
|
||||||
|
@ -318,7 +321,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
|
||||||
}
|
}
|
||||||
make_tuple_inputs.emplace_back(cnode);
|
make_tuple_inputs.emplace_back(cnode);
|
||||||
}
|
}
|
||||||
auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs);
|
auto make_tuple_cnode = res_graph_->NewCNode(make_tuple_inputs);
|
||||||
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
make_tuple_cnode->set_fullname_with_scope("return tuple");
|
||||||
|
|
||||||
std::vector<AnfNodePtr> op_inputs;
|
std::vector<AnfNodePtr> op_inputs;
|
||||||
|
@ -330,9 +333,9 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
|
||||||
auto value_node = NewValueNode(return_prim_ptr);
|
auto value_node = NewValueNode(return_prim_ptr);
|
||||||
op_inputs.emplace_back(value_node);
|
op_inputs.emplace_back(value_node);
|
||||||
op_inputs.emplace_back(make_tuple_cnode);
|
op_inputs.emplace_back(make_tuple_cnode);
|
||||||
auto cnode = func_graph_->NewCNode(op_inputs);
|
auto cnode = res_graph_->NewCNode(op_inputs);
|
||||||
cnode->set_fullname_with_scope("Return");
|
cnode->set_fullname_with_scope("Return");
|
||||||
func_graph_->set_return(cnode);
|
res_graph_->set_return(cnode);
|
||||||
} else {
|
} else {
|
||||||
auto returnPrim = std::make_shared<ops::Return>();
|
auto returnPrim = std::make_shared<ops::Return>();
|
||||||
if (returnPrim == nullptr) {
|
if (returnPrim == nullptr) {
|
||||||
|
@ -350,9 +353,9 @@ STATUS TfliteModelParser::ConvertGraphOutputs() {
|
||||||
return RET_NOT_FIND_OP;
|
return RET_NOT_FIND_OP;
|
||||||
}
|
}
|
||||||
op_inputs.emplace_back(cnode);
|
op_inputs.emplace_back(cnode);
|
||||||
auto returnCnode = func_graph_->NewCNode(op_inputs);
|
auto returnCnode = res_graph_->NewCNode(op_inputs);
|
||||||
returnCnode->set_fullname_with_scope("Return");
|
returnCnode->set_fullname_with_scope("Return");
|
||||||
func_graph_->set_return(returnCnode);
|
res_graph_->set_return(returnCnode);
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -436,8 +439,12 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
|
||||||
std::vector<int64_t> shape_vector;
|
std::vector<int64_t> shape_vector;
|
||||||
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
|
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
|
||||||
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
||||||
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
|
auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type));
|
||||||
dst_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
|
if (abstract_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
dst_cnode->set_abstract(abstract_tensor);
|
||||||
nodes_.insert(std::pair(op->outputs.front(), dst_cnode));
|
nodes_.insert(std::pair(op->outputs.front(), dst_cnode));
|
||||||
} else {
|
} else {
|
||||||
AbstractBasePtrList abstract_list;
|
AbstractBasePtrList abstract_list;
|
||||||
|
@ -450,8 +457,12 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
|
||||||
std::vector<int64_t> shape_vector;
|
std::vector<int64_t> shape_vector;
|
||||||
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
|
(void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector),
|
||||||
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
||||||
auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type));
|
auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type));
|
||||||
abstract_list.emplace_back(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
|
if (abstract_tensor == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
abstract_list.emplace_back(abstract_tensor);
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new TupleGetItem failed";
|
MS_LOG(ERROR) << "new TupleGetItem failed";
|
||||||
|
@ -460,7 +471,7 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
|
||||||
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
|
auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr);
|
||||||
auto get_item_value = NewValueNode(MakeValue<int>(op_idx));
|
auto get_item_value = NewValueNode(MakeValue<int>(op_idx));
|
||||||
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, dst_cnode, get_item_value};
|
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, dst_cnode, get_item_value};
|
||||||
CNodePtr get_item_cnode = func_graph_->NewCNode(inputs);
|
CNodePtr get_item_cnode = res_graph_->NewCNode(inputs);
|
||||||
get_item_cnode->set_fullname_with_scope(dst_cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx));
|
get_item_cnode->set_fullname_with_scope(dst_cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx));
|
||||||
nodes_.insert(std::pair(output_idx, get_item_cnode));
|
nodes_.insert(std::pair(output_idx, get_item_cnode));
|
||||||
op_idx++;
|
op_idx++;
|
||||||
|
@ -469,4 +480,6 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int TfliteModelParser::PostAdjust() { return 0; }
|
||||||
} // namespace mindspore::lite
|
} // namespace mindspore::lite
|
||||||
|
|
|
@ -32,13 +32,14 @@ class TfliteModelParser : public ModelParser {
|
||||||
|
|
||||||
~TfliteModelParser() override = default;
|
~TfliteModelParser() override = default;
|
||||||
|
|
||||||
FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file,
|
int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file,
|
||||||
const QuantType &quant_type) override;
|
const QuantType &quant_type) override;
|
||||||
|
|
||||||
|
int PostAdjust() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unordered_map<int, AnfNodePtr> nodes_;
|
std::unordered_map<int, AnfNodePtr> nodes_;
|
||||||
std::unique_ptr<tflite::ModelT> tflite_model_;
|
std::unique_ptr<tflite::ModelT> tflite_model_;
|
||||||
FuncGraphPtr func_graph_;
|
|
||||||
char *tflite_model_buf_ = nullptr;
|
char *tflite_model_buf_ = nullptr;
|
||||||
std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path);
|
std::unique_ptr<tflite::ModelT> ReadTfliteModel(const char *model_path);
|
||||||
STATUS ConvertConstTensor(const tflite::TensorT *tensor, const ParameterPtr ¶meter,
|
STATUS ConvertConstTensor(const tflite::TensorT *tensor, const ParameterPtr ¶meter,
|
||||||
|
|
|
@ -399,6 +399,24 @@ int CheckIfCNodeIsNull(const CNodePtr &node) {
|
||||||
return lite::RET_OK;
|
return lite::RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int CheckIfParameterIsNull(const ParameterPtr &node) {
|
||||||
|
if (node == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The Parameter is null.";
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return lite::RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int CheckIfValueNodeIsNull(const ValueNodePtr &node) {
|
||||||
|
if (node == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "The ValueNode is null.";
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return lite::RET_NULL_PTR;
|
||||||
|
}
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
int CheckIfVarIsNull(const VarPtr &var) {
|
int CheckIfVarIsNull(const VarPtr &var) {
|
||||||
if (var == nullptr) {
|
if (var == nullptr) {
|
||||||
MS_LOG(ERROR) << "The Var is null.";
|
MS_LOG(ERROR) << "The Var is null.";
|
||||||
|
|
|
@ -57,6 +57,10 @@ int CheckIfAnfNodeIsNull(const AnfNodePtr &node);
|
||||||
|
|
||||||
int CheckIfCNodeIsNull(const CNodePtr &node);
|
int CheckIfCNodeIsNull(const CNodePtr &node);
|
||||||
|
|
||||||
|
int CheckIfParameterIsNull(const ParameterPtr &node);
|
||||||
|
|
||||||
|
int CheckIfValueNodeIsNull(const ValueNodePtr &node);
|
||||||
|
|
||||||
int CheckIfVarIsNull(const VarPtr &var);
|
int CheckIfVarIsNull(const VarPtr &var);
|
||||||
|
|
||||||
int CheckInputSize(const CNodePtr &node, int size);
|
int CheckInputSize(const CNodePtr &node, int size);
|
||||||
|
|
|
@ -0,0 +1,294 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "tools/optimizer/fusion/mul_add_fusion.h"
|
||||||
|
#include <memory>
|
||||||
|
#include "ops/fusion/mul_fusion.h"
|
||||||
|
#include "ops/fusion/add_fusion.h"
|
||||||
|
#include "ops/fusion/scale_fusion.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore::opt {
|
||||||
|
namespace {
|
||||||
|
constexpr size_t kMulInputsLength = 3;
|
||||||
|
constexpr size_t kAddInputsLength = 3;
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
const BaseRef MulAddFusion::DefinePattern() const {
|
||||||
|
auto mul_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
|
||||||
|
auto add_var = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
|
||||||
|
return VectorRef({add_var, mul_var});
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MulAddFusion::ScaleInputShapeValid() const {
|
||||||
|
MS_ASSERT(scale_tensor_ != nullptr);
|
||||||
|
MS_ASSERT(bias_tensor_ != nullptr);
|
||||||
|
auto scale_shape = scale_tensor_->shape_c();
|
||||||
|
auto offset_shape = bias_tensor_->shape_c();
|
||||||
|
if (mul_input_shape_.size() < scale_shape.size() || scale_shape.size() == 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
size_t rank_diff = mul_input_shape_.size() - scale_shape.size();
|
||||||
|
for (size_t i = 0; i < scale_shape.size(); ++i) {
|
||||||
|
if (mul_input_shape_[i + rank_diff] != scale_shape[i]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (scale_shape != offset_shape) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MulAddFusion::CheckMulNode(const FuncGraphPtr &func_graph) const {
|
||||||
|
MS_ASSERT(func_graph != nullptr);
|
||||||
|
if (mul_anode_ == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (IsMultiOutputTensors(func_graph, mul_anode_)) {
|
||||||
|
MS_LOG(DEBUG) << "Mul op has multi-output";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto mul_node = mul_anode_->cast<CNodePtr>();
|
||||||
|
if (!CheckPrimitiveType(mul_node, prim::kPrimMulFusion)) {
|
||||||
|
MS_LOG(DEBUG) << "Mul add fusion pass match only mul or add";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto mul_primitive = GetValueNode<std::shared_ptr<ops::MulFusion>>(mul_node->input(0));
|
||||||
|
MS_ASSERT(mul_primitive != nullptr);
|
||||||
|
auto mul_act_type = mul_primitive->get_activation_type();
|
||||||
|
if (mul_act_type != ActivationType::NO_ACTIVATION) {
|
||||||
|
MS_LOG(DEBUG) << "Only support mul node with no activation";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (CheckIfCNodeIsNull(mul_node) != lite::RET_OK || CheckInputSize(mul_node, kMulInputsLength) != lite::RET_OK) {
|
||||||
|
MS_LOG(DEBUG) << "Mul op is null or has error input size";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// find mul's const input and mul input
|
||||||
|
AnfNodePtr mul_pre_input_node = nullptr;
|
||||||
|
AnfNodePtr mul_pre_const_node = nullptr;
|
||||||
|
auto mul_pre_node_1 = mul_node->input(1);
|
||||||
|
if (CheckIfAnfNodeIsNull(mul_pre_node_1) != lite::RET_OK) {
|
||||||
|
MS_LOG(DEBUG) << "Pre-node of mul op is nullptr";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto mul_pre_node_2 = mul_node->input(2);
|
||||||
|
if (CheckIfAnfNodeIsNull(mul_pre_node_2) != lite::RET_OK) {
|
||||||
|
MS_LOG(DEBUG) << "Pre-node of mul op is nullptr";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (utils::isa<CNodePtr>(mul_pre_node_1) && !utils::isa<CNodePtr>(mul_pre_node_2)) {
|
||||||
|
mul_pre_input_node = mul_pre_node_1;
|
||||||
|
mul_pre_const_node = mul_pre_node_2;
|
||||||
|
} else if (!utils::isa<CNodePtr>(mul_pre_node_1) && utils::isa<CNodePtr>(mul_pre_node_2)) {
|
||||||
|
mul_pre_input_node = mul_pre_node_1;
|
||||||
|
mul_pre_const_node = mul_pre_node_2;
|
||||||
|
} else {
|
||||||
|
MS_LOG(DEBUG) << "Mul op should has a cnode input and a const input";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// check mul's const input
|
||||||
|
tensor::TensorPtr mul_tensor = nullptr;
|
||||||
|
if (utils::isa<ParameterPtr>(mul_pre_const_node)) {
|
||||||
|
auto mul_bias_node = mul_pre_const_node->cast<ParameterPtr>();
|
||||||
|
MS_ASSERT(mul_bias_node != nullptr);
|
||||||
|
if (!mul_bias_node->has_default()) {
|
||||||
|
MS_LOG(DEBUG) << "Const input of mul op should has data";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
mul_tensor = mul_bias_node->default_param()->cast<tensor::TensorPtr>();
|
||||||
|
} else if (utils::isa<ValueNodePtr>(mul_pre_const_node)) {
|
||||||
|
auto mul_bias_node = mul_pre_const_node->cast<ValueNodePtr>();
|
||||||
|
MS_ASSERT(mul_bias_node != nullptr);
|
||||||
|
if (mul_bias_node->value() == nullptr) {
|
||||||
|
MS_LOG(DEBUG) << "Const input of mul op should has data";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
mul_tensor = mul_bias_node->value()->cast<tensor::TensorPtr>();
|
||||||
|
} else {
|
||||||
|
MS_ASSERT(false);
|
||||||
|
}
|
||||||
|
if (mul_tensor == nullptr) {
|
||||||
|
MS_LOG(DEBUG) << "Const input of add op should has data";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
mul_input_anode_ = mul_pre_input_node;
|
||||||
|
mul_const_anode_ = mul_pre_const_node;
|
||||||
|
scale_tensor_ = mul_tensor;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MulAddFusion::CheckAddNode() const {
|
||||||
|
if (add_anode_ == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto add_cnode = add_anode_->cast<CNodePtr>();
|
||||||
|
if (CheckIfCNodeIsNull(add_cnode) != lite::RET_OK || CheckInputSize(add_cnode, kAddInputsLength) != lite::RET_OK) {
|
||||||
|
MS_LOG(DEBUG) << "Add op is null or has error input size";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!CheckPrimitiveType(add_cnode, prim::kPrimAddFusion)) {
|
||||||
|
MS_LOG(DEBUG) << "Mul add fusion pass match only mul or add";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto add_primitive = GetValueNode<std::shared_ptr<ops::AddFusion>>(add_cnode->input(0));
|
||||||
|
MS_ASSERT(add_primitive != nullptr);
|
||||||
|
auto add_act_type = add_primitive->get_activation_type();
|
||||||
|
if (add_act_type != ActivationType::RELU && add_act_type != ActivationType::RELU6 &&
|
||||||
|
add_act_type != ActivationType::NO_ACTIVATION) {
|
||||||
|
MS_LOG(DEBUG) << "Only support add node with relu or relu6 or no activation";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
scale_act_type_ = add_act_type;
|
||||||
|
// find add's const input and mul input
|
||||||
|
AnfNodePtr add_pre_input_node = nullptr;
|
||||||
|
AnfNodePtr add_pre_const_node = nullptr;
|
||||||
|
auto add_pre_node_1 = add_cnode->input(1);
|
||||||
|
if (CheckIfAnfNodeIsNull(add_pre_node_1) != lite::RET_OK) {
|
||||||
|
MS_LOG(DEBUG) << "Pre-node of add op is nullptr";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto add_pre_node_2 = add_cnode->input(2);
|
||||||
|
if (CheckIfAnfNodeIsNull(add_pre_node_2) != lite::RET_OK) {
|
||||||
|
MS_LOG(DEBUG) << "Pre-node of add op is nullptr";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (utils::isa<CNodePtr>(add_pre_node_1) && !utils::isa<CNodePtr>(add_pre_node_2)) {
|
||||||
|
add_pre_input_node = add_pre_node_1;
|
||||||
|
add_pre_const_node = add_pre_node_2;
|
||||||
|
} else if (!utils::isa<CNodePtr>(add_pre_node_1) && utils::isa<CNodePtr>(add_pre_node_2)) {
|
||||||
|
add_pre_input_node = add_pre_node_2;
|
||||||
|
add_pre_const_node = add_pre_node_1;
|
||||||
|
} else {
|
||||||
|
MS_LOG(DEBUG) << "Add op should has a cnode input and a const input";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// check add's const input
|
||||||
|
tensor::TensorPtr add_tensor = nullptr;
|
||||||
|
if (utils::isa<ParameterPtr>(add_pre_const_node)) {
|
||||||
|
auto add_bias_node = add_pre_const_node->cast<ParameterPtr>();
|
||||||
|
MS_ASSERT(add_bias_node != nullptr);
|
||||||
|
if (!add_bias_node->has_default()) {
|
||||||
|
MS_LOG(DEBUG) << "Const input of add op should has data";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
add_tensor = add_bias_node->default_param()->cast<tensor::TensorPtr>();
|
||||||
|
} else if (utils::isa<ValueNodePtr>(add_pre_const_node)) {
|
||||||
|
auto add_bias_node = add_pre_const_node->cast<ValueNodePtr>();
|
||||||
|
MS_ASSERT(add_bias_node != nullptr);
|
||||||
|
if (add_bias_node->value() == nullptr) {
|
||||||
|
MS_LOG(DEBUG) << "Const input of add op should has data";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
add_tensor = add_bias_node->value()->cast<tensor::TensorPtr>();
|
||||||
|
} else {
|
||||||
|
MS_ASSERT(false);
|
||||||
|
}
|
||||||
|
if (add_tensor == nullptr) {
|
||||||
|
MS_LOG(DEBUG) << "Const input of add op should has data";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
mul_anode_ = add_pre_input_node;
|
||||||
|
add_const_anode_ = add_pre_const_node;
|
||||||
|
bias_tensor_ = add_tensor;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MulAddFusion::GetMulInputShape() const {
|
||||||
|
MS_ASSERT(mul_input_anode_ != nullptr);
|
||||||
|
ShapeVector mul_input_shape;
|
||||||
|
AbstractBasePtr mul_input_abstract = nullptr;
|
||||||
|
if (utils::isa<ParameterPtr>(mul_input_anode_)) {
|
||||||
|
auto mul_input_node = mul_input_anode_->cast<ParameterPtr>();
|
||||||
|
MS_ASSERT(mul_bias_node != nullptr);
|
||||||
|
mul_input_abstract = mul_input_node->abstract();
|
||||||
|
} else if (utils::isa<ValueNodePtr>(mul_input_anode_)) {
|
||||||
|
auto mul_input_node = mul_input_anode_->cast<ValueNodePtr>();
|
||||||
|
MS_ASSERT(mul_input_node != nullptr);
|
||||||
|
mul_input_abstract = mul_input_node->abstract();
|
||||||
|
} else if (utils::isa<CNodePtr>(mul_input_anode_)) {
|
||||||
|
auto mul_input_node = mul_input_anode_->cast<CNodePtr>();
|
||||||
|
MS_ASSERT(mul_input_node != nullptr);
|
||||||
|
mul_input_abstract = mul_input_node->abstract();
|
||||||
|
} else {
|
||||||
|
MS_ASSERT(false);
|
||||||
|
}
|
||||||
|
if (mul_input_abstract == nullptr) {
|
||||||
|
MS_LOG(DEBUG) << "Mul input node has no abstract";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!utils::isa<abstract::AbstractTensorPtr>(mul_input_abstract)) {
|
||||||
|
MS_LOG(DEBUG) << "Abstract of mul input node should be AbstractTensor";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(mul_input_abstract);
|
||||||
|
MS_ASSERT(abstract_tensor != nullptr);
|
||||||
|
MS_ASSERT(abstract_tensor->BuildShape() != nullptr);
|
||||||
|
if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
|
||||||
|
MS_LOG(DEBUG) << "BuildShape of abstract of mul input node should be ShapePtr";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
mul_input_shape_ = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||||
|
MS_ASSERT(func_graph != nullptr);
|
||||||
|
MS_ASSERT(node != nullptr);
|
||||||
|
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
|
||||||
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
add_anode_ = node;
|
||||||
|
if (!CheckAddNode()) {
|
||||||
|
MS_LOG(DEBUG) << "Add op is not suit for mul-add-fusion: " << node->fullname_with_scope();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
MS_ASSERT(mul_anode_ != nullptr);
|
||||||
|
MS_ASSERT(bias_tensor_ != nullptr);
|
||||||
|
MS_ASSERT(add_const_anode_ != nullptr);
|
||||||
|
if (!CheckMulNode(func_graph)) {
|
||||||
|
MS_LOG(DEBUG) << "Mul op is not suit for mul-add-fusion: " << mul_anode_->fullname_with_scope();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
MS_ASSERT(mul_input_anode_ != nullptr);
|
||||||
|
MS_ASSERT(scale_tensor_ != nullptr);
|
||||||
|
MS_ASSERT(mul_const_anode_ != nullptr);
|
||||||
|
if (!GetMulInputShape()) {
|
||||||
|
MS_LOG(DEBUG) << "Get input shape of mul op failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
// scale requires scale shape tail sub of input shape, scale shape same as bias shape
|
||||||
|
if (!ScaleInputShapeValid()) {
|
||||||
|
MS_LOG(DEBUG) << "Check input shape, scale shape and bias shape failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
// create scale primitive
|
||||||
|
auto scale_primitive = new (std::nothrow) mindspore::ops::ScaleFusion();
|
||||||
|
if (scale_primitive == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "new scale primitive failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
scale_primitive->set_activation_type(scale_act_type_);
|
||||||
|
scale_primitive->set_axis(0 - bias_tensor_->shape_c().size());
|
||||||
|
// create scale op
|
||||||
|
auto scale_node = func_graph->NewCNode(std::shared_ptr<ops::PrimitiveC>(scale_primitive),
|
||||||
|
{mul_input_anode_, mul_const_anode_, add_const_anode_});
|
||||||
|
return scale_node;
|
||||||
|
}
|
||||||
|
} // namespace mindspore::opt
|
|
@ -0,0 +1,53 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_MUL_ADD_FUSION_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_PASS_FUSION_MUL_ADD_FUSION_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class MulAddFusion : public PatternProcessPass {
|
||||||
|
public:
|
||||||
|
explicit MulAddFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion")
|
||||||
|
: PatternProcessPass(name, multigraph) {}
|
||||||
|
~MulAddFusion() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool CheckMulNode(const FuncGraphPtr &func_graph) const;
|
||||||
|
bool CheckAddNode() const;
|
||||||
|
bool GetMulInputShape() const;
|
||||||
|
bool ScaleInputShapeValid() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
mutable AnfNodePtr mul_anode_ = nullptr;
|
||||||
|
mutable AnfNodePtr mul_input_anode_ = nullptr;
|
||||||
|
mutable AnfNodePtr mul_const_anode_ = nullptr;
|
||||||
|
mutable ShapeVector mul_input_shape_;
|
||||||
|
mutable AnfNodePtr add_anode_ = nullptr;
|
||||||
|
mutable AnfNodePtr add_const_anode_ = nullptr;
|
||||||
|
mutable tensor::TensorPtr scale_tensor_ = nullptr;
|
||||||
|
mutable tensor::TensorPtr bias_tensor_ = nullptr;
|
||||||
|
mutable ActivationType scale_act_type_ = ActivationType::NO_ACTIVATION;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_
|
|
@ -256,11 +256,12 @@ ParameterPtr TfBidirectionGruFusion::AddDefaultParameter(const FuncGraphPtr &fun
|
||||||
auto parameter = func_graph->add_parameter();
|
auto parameter = func_graph->add_parameter();
|
||||||
parameter->set_name(name);
|
parameter->set_name(name);
|
||||||
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
|
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type), shape_vector);
|
auto abstract = lite::CreateTensorAbstract(shape_vector, type);
|
||||||
if (abstract_tensor == nullptr) {
|
if (abstract == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
parameter->set_abstract(abstract_tensor);
|
parameter->set_abstract(abstract);
|
||||||
|
|
||||||
auto gate_weight_default = std::make_shared<tensor::Tensor>(type, shape_vector);
|
auto gate_weight_default = std::make_shared<tensor::Tensor>(type, shape_vector);
|
||||||
if (gate_weight_default == nullptr) {
|
if (gate_weight_default == nullptr) {
|
||||||
|
|
|
@ -502,13 +502,12 @@ CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_grap
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
CNodePtr get_item_cnode = func_graph->NewCNode(tuple_get_item_prim, {node, get_item_value});
|
CNodePtr get_item_cnode = func_graph->NewCNode(tuple_get_item_prim, {node, get_item_value});
|
||||||
std::vector<int64_t> shape_vector;
|
auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
|
if (abstract == nullptr) {
|
||||||
if (abstract_tensor == nullptr) {
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
MS_LOG(ERROR) << "create abstract_tensor failed";
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
get_item_cnode->set_abstract(abstract_tensor);
|
get_item_cnode->set_abstract(abstract);
|
||||||
get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" +
|
get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" +
|
||||||
std::to_string(item_index));
|
std::to_string(item_index));
|
||||||
return get_item_cnode;
|
return get_item_cnode;
|
||||||
|
@ -581,13 +580,12 @@ STATUS TfliteLstmCellFusion::SetAbstractTuple(const CNodePtr &cnode, const int o
|
||||||
MS_ASSERT(cnode != nullptr);
|
MS_ASSERT(cnode != nullptr);
|
||||||
AbstractBasePtrList abstract_list;
|
AbstractBasePtrList abstract_list;
|
||||||
for (int i = 0; i < output_num; ++i) {
|
for (int i = 0; i < output_num; ++i) {
|
||||||
std::vector<int64_t> shape_vector;
|
auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
|
if (abstract == nullptr) {
|
||||||
if (abstract_tensor == nullptr) {
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
MS_LOG(ERROR) << "create abstract_tensor failed";
|
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
abstract_list.emplace_back(abstract_tensor);
|
abstract_list.emplace_back(abstract);
|
||||||
}
|
}
|
||||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||||
if (abstract_tuple == nullptr) {
|
if (abstract_tuple == nullptr) {
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "ops/return.h"
|
#include "ops/return.h"
|
||||||
#include "ops/tuple_get_item.h"
|
#include "ops/tuple_get_item.h"
|
||||||
#include "tools/converter/ops/while.h"
|
#include "tools/converter/ops/while.h"
|
||||||
|
#include "tools/common/tensor_util.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
mindspore::ValueNodePtr GetWhileAnfPrim() {
|
mindspore::ValueNodePtr GetWhileAnfPrim() {
|
||||||
|
@ -207,9 +208,13 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() {
|
||||||
auto node_users = manager->node_users()[node];
|
auto node_users = manager->node_users()[node];
|
||||||
for (auto &node_user : node_users) {
|
for (auto &node_user : node_users) {
|
||||||
// new getitem
|
// new getitem
|
||||||
AbstractBasePtrList abstractList;
|
AbstractBasePtrList abstract_list;
|
||||||
std::vector<int64_t> shape_vector;
|
auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
|
||||||
abstractList.emplace_back(std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector));
|
if (abstract == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
abstract_list.emplace_back(abstract);
|
||||||
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
auto tuple_get_item_prim_ptr = std::make_shared<ops::TupleGetItem>();
|
||||||
if (tuple_get_item_prim_ptr == nullptr) {
|
if (tuple_get_item_prim_ptr == nullptr) {
|
||||||
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
|
MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr";
|
||||||
|
@ -225,12 +230,12 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() {
|
||||||
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, while_node_, getItemValue};
|
std::vector<AnfNodePtr> inputs{tuple_get_item_prim, while_node_, getItemValue};
|
||||||
CNodePtr get_item_node = fg_->NewCNode(inputs);
|
CNodePtr get_item_node = fg_->NewCNode(inputs);
|
||||||
std::string output_item_name = while_node_->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
|
std::string output_item_name = while_node_->fullname_with_scope() + "_getitem_" + std::to_string(output_idx);
|
||||||
auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape_vector);
|
auto get_item_node_abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32);
|
||||||
if (abstract == nullptr) {
|
if (get_item_node_abstract == nullptr) {
|
||||||
MS_LOG(ERROR) << "create AbstractTensor failed";
|
MS_LOG(ERROR) << "Create get_item_node_abstract failed";
|
||||||
return RET_NULL_PTR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
get_item_node->set_abstract(abstract);
|
get_item_node->set_abstract(get_item_node_abstract);
|
||||||
get_item_node->set_fullname_with_scope(output_item_name);
|
get_item_node->set_fullname_with_scope(output_item_name);
|
||||||
// set
|
// set
|
||||||
if (fg_->nodes().contains(node_user.first)) {
|
if (fg_->nodes().contains(node_user.first)) {
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "src/tensor.h"
|
#include "src/tensor.h"
|
||||||
#include "tools/converter/quantizer/quant_cast.h"
|
#include "tools/converter/quantizer/quant_cast.h"
|
||||||
#include "src/common/log_adapter.h"
|
#include "src/common/log_adapter.h"
|
||||||
|
#include "tools/common/tensor_util.h"
|
||||||
#include "securec/include/securec.h"
|
#include "securec/include/securec.h"
|
||||||
|
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
|
@ -101,13 +102,16 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto type_id = static_cast<TypeId>(weight_value->data_type());
|
auto type_id = static_cast<TypeId>(weight_value->data_type());
|
||||||
auto type_ptr = TypeIdToType(type_id);
|
|
||||||
auto shape = weight_value->shape();
|
auto shape = weight_value->shape();
|
||||||
std::vector<int64_t> shape_vector;
|
std::vector<int64_t> shape_vector;
|
||||||
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
|
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
|
||||||
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
|
||||||
weight_node->set_abstract(abstract_tensor);
|
if (abstract == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
weight_node->set_abstract(abstract);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
#include "tools/common/node_util.h"
|
#include "tools/common/node_util.h"
|
||||||
#include "tools/common/tensor_util.h"
|
#include "tools/common/tensor_util.h"
|
||||||
#include "src/common/common.h"
|
#include "src/common/common.h"
|
||||||
|
#include "src/common/tensor_util.h"
|
||||||
#include "src/ops/populate/populate_register.h"
|
#include "src/ops/populate/populate_register.h"
|
||||||
#include "src/ops/ops_utils.h"
|
#include "src/ops/ops_utils.h"
|
||||||
#include "src/runtime/infer_manager.h"
|
#include "src/runtime/infer_manager.h"
|
||||||
|
@ -28,19 +29,6 @@
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
namespace {
|
namespace {
|
||||||
constexpr size_t INITIAL_SIZE = 1024;
|
constexpr size_t INITIAL_SIZE = 1024;
|
||||||
tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) {
|
|
||||||
std::vector<int> shape(tensor->shape());
|
|
||||||
std::vector<int64_t> shape_vector;
|
|
||||||
std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
|
|
||||||
[](const int32_t &value) { return static_cast<int64_t>(value); });
|
|
||||||
auto tensor_info = std::make_shared<tensor::Tensor>(tensor->data_type(), shape_vector);
|
|
||||||
if (tensor_info == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "new tensor::Tensor failed";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
return tensor_info;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool IsSpecialType(const CNodePtr &cnode) {
|
bool IsSpecialType(const CNodePtr &cnode) {
|
||||||
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
|
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) ||
|
||||||
CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimReturn) ||
|
CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimReturn) ||
|
||||||
|
@ -75,21 +63,14 @@ STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) {
|
abstract::AbstractBasePtr InferShapePass::ConvertLiteTensorToAbstract(lite::Tensor *tensor) {
|
||||||
MS_ASSERT(nullptr != tensor);
|
MS_ASSERT(nullptr != tensor);
|
||||||
std::vector<int> shape(tensor->shape());
|
auto shape = tensor->shape();
|
||||||
auto type_id = static_cast<TypeId>(tensor->data_type());
|
auto type_id = static_cast<TypeId>(tensor->data_type());
|
||||||
auto type_ptr = TypeIdToType(type_id);
|
|
||||||
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
|
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
|
||||||
auto new_abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
auto tensor_info = lite::CreateTensorInfo(nullptr, 0, shape_vector, type_id);
|
||||||
if (new_abstract == nullptr) {
|
|
||||||
MS_LOG(ERROR) << "new AbstractTensor failed";
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto tensor_info = NewTensorInfo(tensor);
|
|
||||||
if (tensor_info == nullptr) {
|
if (tensor_info == nullptr) {
|
||||||
MS_LOG(ERROR) << "new tensor::Tensor failed";
|
MS_LOG(DEBUG) << "Create tensor info failed";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,8 +93,12 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
new_abstract->set_value(tensor_info);
|
auto abstract = tensor_info->ToAbstract();
|
||||||
return new_abstract;
|
if (abstract == nullptr) {
|
||||||
|
MS_LOG(DEBUG) << "Create tensor abstarct failed";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return abstract;
|
||||||
}
|
}
|
||||||
|
|
||||||
STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) {
|
STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) {
|
||||||
|
@ -143,8 +128,6 @@ STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) {
|
||||||
std::vector<int32_t> shape;
|
std::vector<int32_t> shape;
|
||||||
(void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
|
(void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
|
||||||
[](const int64_t &value) { return static_cast<int32_t>(value); });
|
[](const int64_t &value) { return static_cast<int32_t>(value); });
|
||||||
|
|
||||||
auto new_abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
|
||||||
auto new_tensor_info = std::make_shared<tensor::Tensor>(type_ptr->type_id(), shape_vector);
|
auto new_tensor_info = std::make_shared<tensor::Tensor>(type_ptr->type_id(), shape_vector);
|
||||||
if (parameter->has_default()) {
|
if (parameter->has_default()) {
|
||||||
auto old_tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param());
|
auto old_tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(parameter->default_param());
|
||||||
|
@ -155,7 +138,11 @@ STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
new_abstract->set_value(new_tensor_info);
|
auto new_abstract = new_tensor_info->ToAbstract();
|
||||||
|
if (new_abstract == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
parameter->set_abstract(new_abstract);
|
parameter->set_abstract(new_abstract);
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
@ -304,7 +291,7 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector<lite::Tensor *> &outpu
|
||||||
}
|
}
|
||||||
if (output_tensors.size() == 1) {
|
if (output_tensors.size() == 1) {
|
||||||
auto tensor = output_tensors.front();
|
auto tensor = output_tensors.front();
|
||||||
auto new_abstract = ConvertLiteTensorToAbstractTensor(tensor);
|
auto new_abstract = ConvertLiteTensorToAbstract(tensor);
|
||||||
if (new_abstract == nullptr) {
|
if (new_abstract == nullptr) {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -313,7 +300,7 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector<lite::Tensor *> &outpu
|
||||||
AbstractBasePtrList abstract_list;
|
AbstractBasePtrList abstract_list;
|
||||||
for (size_t i = 0; i < output_tensors.size(); i++) {
|
for (size_t i = 0; i < output_tensors.size(); i++) {
|
||||||
auto tensor = output_tensors.front();
|
auto tensor = output_tensors.front();
|
||||||
auto new_abstract = ConvertLiteTensorToAbstractTensor(tensor);
|
auto new_abstract = ConvertLiteTensorToAbstract(tensor);
|
||||||
if (new_abstract == nullptr) {
|
if (new_abstract == nullptr) {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ class InferShapePass : public Pass {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void FreeTensors(std::vector<lite::Tensor *> *tensors);
|
void FreeTensors(std::vector<lite::Tensor *> *tensors);
|
||||||
abstract::AbstractTensorPtr ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor);
|
abstract::AbstractBasePtr ConvertLiteTensorToAbstract(lite::Tensor *tensor);
|
||||||
STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *input_tensors);
|
STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *input_tensors);
|
||||||
STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors);
|
STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *output_tensors);
|
||||||
STATUS SetParameterAbstract(const ParameterPtr ¶meter);
|
STATUS SetParameterAbstract(const ParameterPtr ¶meter);
|
||||||
|
|
|
@ -179,23 +179,23 @@ int MindirAdjustPass::ValueNodeInt64Convert(AnfNodePtr anf_node) {
|
||||||
if (!utils::isa<ValueNodePtr>(anf_node)) {
|
if (!utils::isa<ValueNodePtr>(anf_node)) {
|
||||||
return lite::RET_NO_CHANGE;
|
return lite::RET_NO_CHANGE;
|
||||||
}
|
}
|
||||||
auto valueNode = anf_node->cast<ValueNodePtr>();
|
auto value_node = anf_node->cast<ValueNodePtr>();
|
||||||
if (valueNode->abstract() == nullptr) {
|
if (value_node->abstract() == nullptr) {
|
||||||
return lite::RET_NO_CHANGE;
|
return lite::RET_NO_CHANGE;
|
||||||
}
|
}
|
||||||
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueNode->abstract());
|
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(value_node->abstract());
|
||||||
if (abstractTensor == nullptr) {
|
if (abstract_tensor == nullptr) {
|
||||||
return lite::RET_NO_CHANGE;
|
return lite::RET_NO_CHANGE;
|
||||||
}
|
}
|
||||||
auto value = abstractTensor->GetValueTrack();
|
auto value = abstract_tensor->GetValueTrack();
|
||||||
if (value != nullptr && value->isa<tensor::Tensor>()) {
|
if (value != nullptr && value->isa<tensor::Tensor>()) {
|
||||||
if (abstractTensor->element() == nullptr) {
|
if (abstract_tensor->element() == nullptr) {
|
||||||
MS_LOG(ERROR) << "abstractTensor->element() is nullptr.";
|
MS_LOG(ERROR) << "abstractTensor->element() is nullptr.";
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
auto typePtr = abstractTensor->element()->GetTypeTrack();
|
auto type_ptr = abstract_tensor->element()->GetTypeTrack();
|
||||||
if (typePtr->type_id() == kNumberTypeInt64) {
|
if (type_ptr->type_id() == kNumberTypeInt64) {
|
||||||
auto shape_vector = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape();
|
auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
|
||||||
auto dest_tensor_info = std::make_shared<tensor::Tensor>(kNumberTypeInt32, shape_vector);
|
auto dest_tensor_info = std::make_shared<tensor::Tensor>(kNumberTypeInt32, shape_vector);
|
||||||
auto *dest_data_buf = reinterpret_cast<int32_t *>(dest_tensor_info->data_c());
|
auto *dest_data_buf = reinterpret_cast<int32_t *>(dest_tensor_info->data_c());
|
||||||
auto src_tensor_info = value->cast<tensor::TensorPtr>();
|
auto src_tensor_info = value->cast<tensor::TensorPtr>();
|
||||||
|
@ -204,10 +204,10 @@ int MindirAdjustPass::ValueNodeInt64Convert(AnfNodePtr anf_node) {
|
||||||
for (int i = 0; i < dest_tensor_info->ElementsNum(); i++) {
|
for (int i = 0; i < dest_tensor_info->ElementsNum(); i++) {
|
||||||
dest_data_buf[i] = src_data_buf[i];
|
dest_data_buf[i] = src_data_buf[i];
|
||||||
}
|
}
|
||||||
abstractTensor->set_value(dest_tensor_info);
|
abstract_tensor->set_value(dest_tensor_info);
|
||||||
abstractTensor->set_type(TypeIdToType(kNumberTypeInt32));
|
abstract_tensor->set_type(TypeIdToType(kNumberTypeInt32));
|
||||||
abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt32));
|
abstract_tensor->element()->set_type(TypeIdToType(kNumberTypeInt32));
|
||||||
valueNode->set_value(dest_tensor_info);
|
value_node->set_value(dest_tensor_info);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return lite::RET_NO_CHANGE;
|
return lite::RET_NO_CHANGE;
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "ops/transpose.h"
|
#include "ops/transpose.h"
|
||||||
#include "tools/optimizer/common/gllo_utils.h"
|
#include "tools/optimizer/common/gllo_utils.h"
|
||||||
|
#include "tools/common/tensor_util.h"
|
||||||
|
|
||||||
using mindspore::lite::converter::FmkType_CAFFE;
|
using mindspore::lite::converter::FmkType_CAFFE;
|
||||||
using mindspore::lite::converter::FmkType_MS;
|
using mindspore::lite::converter::FmkType_MS;
|
||||||
|
@ -92,9 +93,20 @@ lite::STATUS WeightFormatTransformPass::TransposeInsertForWeightSharing(const Fu
|
||||||
auto perm_node = BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_perm");
|
auto perm_node = BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_perm");
|
||||||
auto prim = std::make_shared<ops::Transpose>();
|
auto prim = std::make_shared<ops::Transpose>();
|
||||||
auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
|
auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node});
|
||||||
auto type_ptr = TypeIdToType(kTypeUnknown);
|
if (!weight_node->has_default()) {
|
||||||
std::vector<int64_t> shape_vector;
|
MS_LOG(DEBUG) << "Weight parameter should has default parameter.";
|
||||||
auto abstract = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
auto weight_tensor = weight_node->default_param()->cast<tensor::TensorPtr>();
|
||||||
|
if (weight_tensor == nullptr) {
|
||||||
|
MS_LOG(DEBUG) << "Default parameter of weight parameter should be a tensor.";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
auto abstract = lite::CreateTensorAbstract(weight_tensor->shape_c(), weight_tensor->data_type());
|
||||||
|
if (abstract == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
transpose_node->set_abstract(abstract);
|
transpose_node->set_abstract(abstract);
|
||||||
transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_post");
|
transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_post");
|
||||||
for (auto &adjust_node : adjust_nodes) {
|
for (auto &adjust_node : adjust_nodes) {
|
||||||
|
@ -177,11 +189,14 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
auto type_id = static_cast<TypeId>(weight_value->data_type());
|
auto type_id = static_cast<TypeId>(weight_value->data_type());
|
||||||
auto type_ptr = TypeIdToType(type_id);
|
|
||||||
auto shape = weight_value->shape();
|
auto shape = weight_value->shape();
|
||||||
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
|
std::vector<int64_t> shape_vector(shape.begin(), shape.end());
|
||||||
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
|
auto abstract = lite::CreateTensorAbstract(shape_vector, type_id);
|
||||||
weight_node->set_abstract(abstract_tensor);
|
if (abstract == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Create tensor abstarct failed";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
weight_node->set_abstract(abstract);
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue