From 84e4906c9d01829d35e2f7fe0b57aa4094af9f18 Mon Sep 17 00:00:00 2001 From: hangangqiang Date: Thu, 22 Apr 2021 16:50:38 +0800 Subject: [PATCH] fix abstract of parameter --- .../cpu/nnacl/infer/topk_infer.c | 5 +- mindspore/lite/src/lite_model.h | 10 +- mindspore/lite/test/CMakeLists.txt | 1 + mindspore/lite/tools/common/tensor_util.cc | 34 ++ mindspore/lite/tools/common/tensor_util.h | 5 + mindspore/lite/tools/converter/CMakeLists.txt | 1 + mindspore/lite/tools/converter/converter.cc | 2 +- .../tools/converter/graphdef_transform.cc | 166 ++-------- .../graph/batchnorm_convert_scale_pass.h | 4 +- .../graph/dtype_trans_pass.cc | 5 - .../legacy_optimizer/graph/dtype_trans_pass.h | 7 +- .../legacy_optimizer/graph/infershape_pass.h | 8 +- mindspore/lite/tools/converter/model_parser.h | 24 +- mindspore/lite/tools/converter/ops/while.cc | 5 +- .../parser/caffe/caffe_model_parser.cc | 92 ++++-- .../parser/caffe/caffe_model_parser.h | 7 +- .../parser/onnx/onnx_model_parser.cc | 98 ++++-- .../converter/parser/onnx/onnx_model_parser.h | 10 +- .../converter/parser/tf/tf_model_parser.cc | 106 ++++--- .../converter/parser/tf/tf_model_parser.h | 7 +- .../parser/tflite/tflite_model_parser.cc | 65 ++-- .../parser/tflite/tflite_model_parser.h | 7 +- .../lite/tools/optimizer/common/gllo_utils.cc | 18 ++ .../lite/tools/optimizer/common/gllo_utils.h | 4 + .../tools/optimizer/fusion/mul_add_fusion.cc | 294 ++++++++++++++++++ .../tools/optimizer/fusion/mul_add_fusion.h | 53 ++++ .../fusion/tf_bidirection_gru_fusion.cc | 7 +- .../fusion/tflite_lstm_cell_fusion.cc | 18 +- .../optimizer/graph/functionalize_while.cc | 21 +- .../graph/group_depthwise_op_convert_pass.cc | 10 +- .../tools/optimizer/graph/infershape_pass.cc | 49 ++- .../tools/optimizer/graph/infershape_pass.h | 2 +- .../optimizer/graph/mindir_adjust_pass.cc | 26 +- .../graph/weight_format_transform_pass.cc | 27 +- 34 files changed, 797 insertions(+), 401 deletions(-) create mode 100644 mindspore/lite/tools/optimizer/fusion/mul_add_fusion.cc create mode 100644 mindspore/lite/tools/optimizer/fusion/mul_add_fusion.h diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/topk_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/topk_infer.c index 4885334049d..59bc152ae3d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/topk_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/topk_infer.c @@ -38,9 +38,12 @@ int TopKInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o if (!parameter->infer_flag_) { return NNACL_INFER_INVALID; } + const TensorC *input_k_tensor = inputs[1]; + if (input_k_tensor->data_ == NULL) { + return NNACL_INFER_INVALID; + } TopkParameter *param = (TopkParameter *)parameter; - const TensorC *input_k_tensor = inputs[1]; param->k_ = ((int32_t *)input_k_tensor->data_)[0]; int out_shape[MAX_SHAPE_SIZE]; diff --git a/mindspore/lite/src/lite_model.h b/mindspore/lite/src/lite_model.h index 944de3db389..8877f5637c5 100644 --- a/mindspore/lite/src/lite_model.h +++ b/mindspore/lite/src/lite_model.h @@ -75,12 +75,14 @@ class LiteModel : public Model { } else { node->name_ = c_node->name()->c_str(); } - auto count = c_node->inputIndex()->size(); - for (uint32_t j = 0; j < count; ++j) { - node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs(j))); + if (c_node->inputIndex() != nullptr) { + auto count = c_node->inputIndex()->size(); + for (uint32_t j = 0; j < count; ++j) { + node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs(j))); + } } if (c_node->outputIndex() != nullptr) { - count = c_node->outputIndex()->size(); + auto count = c_node->outputIndex()->size(); for (uint32_t j = 0; j < count; ++j) { node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs(j))); } diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 8a47bff19d1..a3a755fa871 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -247,6 +247,7 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/tf_bidirection_gru_cf_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/matmul_add_fusion.cc + ${LITE_DIR}/tools/optimizer/fusion/mul_add_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/gelu_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/tf_gelu_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/onnx_gelu_fusion.cc diff --git a/mindspore/lite/tools/common/tensor_util.cc b/mindspore/lite/tools/common/tensor_util.cc index ca705654182..99e2a8ac236 100644 --- a/mindspore/lite/tools/common/tensor_util.cc +++ b/mindspore/lite/tools/common/tensor_util.cc @@ -73,6 +73,40 @@ tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std return tensor_info; } +AbstractBasePtr CreateTensorAbstract(const std::vector &shape, TypeId data_type) { + auto tensor_info = CreateTensorInfo(nullptr, 0, shape, data_type); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Create tensor info failed"; + return nullptr; + } + auto abstract = tensor_info->ToAbstract(); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return nullptr; + } + return abstract; +} + +int SetParameterAbstractAndParam(const ParameterPtr ¶meter, const void *data, size_t data_size, + const std::vector &shape, TypeId data_type) { + if (parameter == nullptr) { + MS_LOG(ERROR) << "Input parameter is nullptr"; + return RET_INPUT_PARAM_INVALID; + } + auto tensor_info = CreateTensorInfo(data, data_size, shape, data_type); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Create tensor info failed"; + return RET_ERROR; + } + auto abstract = tensor_info->ToAbstract(); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + parameter->set_abstract(abstract); + return RET_OK; +} + int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size) { if (tensor_info == nullptr) { MS_LOG(ERROR) << "tensor info is nullptr."; diff --git a/mindspore/lite/tools/common/tensor_util.h b/mindspore/lite/tools/common/tensor_util.h index 6c5d18fd5a5..d69568e2be9 100644 --- a/mindspore/lite/tools/common/tensor_util.h +++ b/mindspore/lite/tools/common/tensor_util.h @@ -46,6 +46,11 @@ std::unique_ptr GetTensorQuantParam(const std::unique_ptr tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std::vector &shape, TypeId data_type); +AbstractBasePtr CreateTensorAbstract(const std::vector &shape, TypeId data_type); + +int SetParameterAbstractAndParam(const ParameterPtr ¶meter, const void *data, size_t data_size, + const std::vector &shape, TypeId data_type); + int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size); std::unique_ptr CreateTensorTFromTensorInfo(const tensor::TensorPtr &tensor_info, diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 0807d5ed3ea..704b000a43c 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -54,6 +54,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/fusion/tf_bidirection_gru_fusion.cc ../optimizer/fusion/tf_bidirection_gru_cf_fusion.cc ../optimizer/fusion/matmul_add_fusion.cc + ../optimizer/fusion/mul_add_fusion.cc ../optimizer/fusion/gelu_fusion.cc ../optimizer/fusion/tf_gelu_fusion.cc ../optimizer/fusion/onnx_gelu_fusion.cc diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 9dc2a32f6ed..88dc7841cde 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -70,7 +70,7 @@ MetaGraphT *Converter::Convert(const std::unique_ptr &flag) { } MS_LOG(INFO) << "Run anfTransform success"; - // protobuf -> flatbuf + // protobuf -> flatbuffer auto meta_graph = Export(graph, false, false, flag->trainModel); if (meta_graph == nullptr) { MS_LOG(ERROR) << "Export to meta graph return nullptr"; diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index d612ffddb87..cba5dede3ff 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -39,7 +39,6 @@ using std::string; namespace mindspore::lite { - std::vector GraphDefTransform::GetGraphNodes() { std::vector old_nodes{}; old_nodes.resize(graph_defT_->nodes.size()); @@ -71,54 +70,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } - // generate and infer quant parameters - { - Optimizer infer_quant_param_pass; - infer_quant_param_pass.AddPass(new (std::nothrow) TopologicalSortPass()); - infer_quant_param_pass.AddPass(new (std::nothrow) InferQuantParamPass()); - status = infer_quant_param_pass.Run(graph_defT_); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run infer_quant_param_pass graphPasses Failed"; - return status; - } - } - - { - // format transform - // init old node indices - auto old_nodes = GetGraphNodes(); - - Optimizer format_trans_optimizer; - format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); - format_trans_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - if (ctx.fmk != converter::FmkType_TF) { - auto infer_shape_pass = new (std::nothrow) InferShapePass(); - if (infer_shape_pass == nullptr) { - MS_LOG(ERROR) << "new InferShapePass failed"; - return RET_MEMORY_FAILED; - } - infer_shape_pass->set_fmk_type(ctx.fmk); - format_trans_optimizer.AddPass(infer_shape_pass); - } - status = format_trans_optimizer.Run(graph_defT_); - if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { - MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed"; - return status; - } - } - { - // init old node indices - auto old_nodes = GetGraphNodes(); - Optimizer format_trans_optimizer; - format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); - status = format_trans_optimizer.Run(graph_defT_); - if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { - MS_LOG(ERROR) << "Run format_trans_optimizer graphPasses Failed"; - return status; - } - } - + // format transpose global optimize { // init old node indices auto old_nodes = GetGraphNodes(); @@ -134,20 +86,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } - // postconvert pass - { + // node replace + if (!ctx.trainModel) { // init old node indices auto old_nodes = GetGraphNodes(); Optimizer replace_optimizer; - if (!ctx.trainModel) { - auto batch_norm_scale_pass = new (std::nothrow) BatchNormConvertScalePass(); - if (batch_norm_scale_pass == nullptr) { - MS_LOG(ERROR) << "new batch_norm_scale_pass failed."; - return RET_ERROR; - } - batch_norm_scale_pass->SetFmk(ctx.fmk); - replace_optimizer.AddPass(batch_norm_scale_pass); - } + replace_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk)); + replace_optimizer.AddPass(new (std::nothrow) BatchNormConvertScalePass(ctx.fmk)); replace_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); replace_optimizer.AddPass(new SubgraphNodePass(old_nodes)); status = replace_optimizer.Run(graph_defT_); @@ -157,6 +102,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } + // node fusion { // init old node indices auto old_nodes = GetGraphNodes(); @@ -171,19 +117,14 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } - // do quantization + // quantization if (ctx.fmk != converter::FmkType_TF) { // init old node indices auto old_nodes = GetGraphNodes(); Optimizer tensor_quant_optimizer; tensor_quant_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - auto infer_shape_pass = new (std::nothrow) InferShapePass(); - if (infer_shape_pass == nullptr) { - MS_LOG(ERROR) << "new InferShapePass failed"; - return RET_MEMORY_FAILED; - } - infer_shape_pass->set_fmk_type(ctx.fmk); - tensor_quant_optimizer.AddPass(infer_shape_pass); + tensor_quant_optimizer.AddPass(new (std::nothrow) InferQuantParamPass()); + tensor_quant_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk)); tensor_quant_optimizer.AddPass(new (std::nothrow) TensorQuantPass()); tensor_quant_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); status = tensor_quant_optimizer.Run(graph_defT_); @@ -193,38 +134,17 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } - // insert quantNode and deQuantNode + // quantization if (ctx.fmk != converter::FmkType_TF) { // init old node indices - auto old_nodes = GetGraphNodes(); Optimizer quant_node_optimizer; - quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); quant_node_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - auto infer_shape_pass = new (std::nothrow) InferShapePass(); - if (infer_shape_pass == nullptr) { - MS_LOG(ERROR) << "new InferShapePass failed"; - return RET_MEMORY_FAILED; - } - infer_shape_pass->set_fmk_type(ctx.fmk); - quant_node_optimizer.AddPass(infer_shape_pass); - status = quant_node_optimizer.Run(graph_defT_); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed"; - return status; - } - auto old_nodes2 = GetGraphNodes(); - quant_node_optimizer.AddPass(new (std::nothrow) InferQuantParamPass()); - auto dtype_trans_pass = new (std::nothrow) DTypeTransPass(); - if (dtype_trans_pass == nullptr) { - MS_LOG(ERROR) << "new dtype_trans_pass failed"; - return RET_MEMORY_FAILED; - } - dtype_trans_pass->set_input_data_dtype(ctx.inputDataType); - dtype_trans_pass->set_output_data_dtype(ctx.outputDataType); - quant_node_optimizer.AddPass(dtype_trans_pass); + auto old_nodes = GetGraphNodes(); + quant_node_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk)); + quant_node_optimizer.AddPass(new (std::nothrow) DTypeTransPass(ctx.inputDataType, ctx.outputDataType)); quant_node_optimizer.AddPass(new (std::nothrow) QuantCastFusionPass()); quant_node_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); - quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes2)); + quant_node_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); status = quant_node_optimizer.Run(graph_defT_); if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "Run quant_node_optimizer graphPasses Failed"; @@ -232,7 +152,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } - // switch pass + // controlflow pass { // init old node indices auto old_nodes = GetGraphNodes(); @@ -240,6 +160,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { switch_optimizer.AddPass(new (std::nothrow) SwitchPass()); switch_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); switch_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + switch_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass()); status = switch_optimizer.Run(graph_defT_); if (status != RET_OK && status != RET_NO_CHANGE) { MS_LOG(ERROR) << "Run switch_optimizer Failed"; @@ -247,34 +168,11 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } } - // subgraph tensor pass - { - Optimizer subgraph_tensor_optimizer; - subgraph_tensor_optimizer.AddPass(new (std::nothrow) SubgraphTensorPass()); - status = subgraph_tensor_optimizer.Run(graph_defT_); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run subgraph tensor pass Failed"; - return status; - } - } - - // tensor name - { - // init old node indices - auto old_nodes = GetGraphNodes(); - Optimizer name_optimizer; - name_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); - name_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); - name_optimizer.AddPass(new (std::nothrow) TensorNamePass()); - status = name_optimizer.Run(graph_defT_); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run name_optimizer graphPasses Failed"; - return status; - } - } - { Optimizer nested_loop_optimizer; + auto old_nodes = GetGraphNodes(); + nested_loop_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); + nested_loop_optimizer.AddPass(new (std::nothrow) TopologicalSortPass()); nested_loop_optimizer.AddPass(new (std::nothrow) NestedLoopExpandPass()); status = nested_loop_optimizer.Run(graph_defT_); if (status != RET_OK && status != RET_NO_CHANGE) { @@ -284,30 +182,16 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { } { - Optimizer quant_param_optimizer; - quant_param_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); - status = quant_param_optimizer.Run(graph_defT_); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run quant_param_optimizer graphPasses Failed"; - return status; - } - } - - { - Optimizer infer_shape_optimizer; - auto infer_shape_pass = new (std::nothrow) InferShapePass(); - if (infer_shape_pass == nullptr) { - MS_LOG(ERROR) << "new InferShapePass failed"; - return RET_MEMORY_FAILED; - } - infer_shape_pass->set_fmk_type(ctx.fmk); - infer_shape_optimizer.AddPass(infer_shape_pass); - status = infer_shape_optimizer.Run(graph_defT_); + Optimizer forming_model_optimizer; + forming_model_optimizer.AddPass(new (std::nothrow) InferShapePass(ctx.fmk)); + forming_model_optimizer.AddPass(new (std::nothrow) SetUnusedQuantParamToDefaultPass()); + forming_model_optimizer.AddPass(new (std::nothrow) TensorNamePass()); + status = forming_model_optimizer.Run(graph_defT_); if (status != RET_OK) { MS_LOG(ERROR) << "Run InferShapeOptimizer graphPasses Failed."; return status; } } return RET_OK; -} // namespace mindspore::lite +} } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h index eeab6e3edaa..3844f660975 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h @@ -36,14 +36,12 @@ struct BNWeightTensors { }; class BatchNormConvertScalePass : public GraphPass { public: - BatchNormConvertScalePass() = default; + explicit BatchNormConvertScalePass(converter::FmkType fmk) : fmkType(fmk) {} ~BatchNormConvertScalePass() = default; STATUS Run(MetaGraphT *graph) override; - void SetFmk(converter::FmkType fmk) { this->fmkType = fmk; } - protected: STATUS GetTransParam(MetaGraphT *graph, const std::unique_ptr &bnNode); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc index f036c95eb3d..5e4016bcdcf 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc @@ -276,10 +276,5 @@ NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIte return InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num, castOpCopyer); } - -void DTypeTransPass::set_input_data_dtype(TypeId data_type) { this->input_data_dtype = data_type; } - -void DTypeTransPass::set_output_data_dtype(TypeId data_type) { this->output_data_dtype = data_type; } - } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h index f595809d630..3592a3dafe0 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.h @@ -30,16 +30,13 @@ enum DTypeTransNodeType { kInt8ToFP32, kFP32ToInt8, kUInt8ToInt8, kInt8ToUInt8 } class DTypeTransPass : public GraphPass { public: - DTypeTransPass() : id_(0) {} + DTypeTransPass(TypeId model_input_data_type, TypeId model_output_data_type) + : id_(0), input_data_dtype(model_input_data_type), output_data_dtype(model_output_data_type) {} ~DTypeTransPass() override = default; STATUS Run(schema::MetaGraphT *graph) override; - void set_input_data_dtype(TypeId data_type); - - void set_output_data_dtype(TypeId dataType); - private: STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.h index 8e3733b556a..d5fda95ec3c 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/infershape_pass.h @@ -39,14 +39,10 @@ struct InferTensor { class InferShapePass : public GraphPass { public: - InferShapePass() = default; - - ~InferShapePass() = default; - + explicit InferShapePass(converter::FmkType fmk_type) : fmk_type_(fmk_type) {} + ~InferShapePass() override = default; STATUS Run(MetaGraphT *graph) override; - void set_fmk_type(converter::FmkType fmk_type) { this->fmk_type_ = fmk_type; } - private: void InitSearchTensor(MetaGraphT *graph); void AddNextInferShapeNode(std::vector output_tensor_node_indexes, size_t index); diff --git a/mindspore/lite/tools/converter/model_parser.h b/mindspore/lite/tools/converter/model_parser.h index d7eb1c41637..308855958e1 100644 --- a/mindspore/lite/tools/converter/model_parser.h +++ b/mindspore/lite/tools/converter/model_parser.h @@ -34,8 +34,28 @@ class ModelParser { virtual ~ModelParser() = default; - virtual FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) = 0; + FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, const QuantType &quant_type) { + auto ret = ParseToFuncGraph(model_file, weight_file, quant_type); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Parse to func graph failed : " << ret; + return nullptr; + } + ret = PostAdjust(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Adjust func graph failed : " << ret; + return nullptr; + } + return this->res_graph_; + } + + protected: + virtual int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) = 0; + + virtual int PostAdjust() = 0; + + protected: + FuncGraphPtr res_graph_ = nullptr; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/ops/while.cc b/mindspore/lite/tools/converter/ops/while.cc index b8cc501143b..b81cca024ad 100644 --- a/mindspore/lite/tools/converter/ops/while.cc +++ b/mindspore/lite/tools/converter/ops/while.cc @@ -15,6 +15,7 @@ */ #include +#include "tools/common/tensor_util.h" #include "tools/converter/ops/while.h" #include "utils/check_convert_utils.h" #include "abstract/primitive_infer_map.h" @@ -55,7 +56,9 @@ AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP AbstractBasePtrList output; for (int64_t i = 0; i < (int64_t)input_args.size(); i++) { auto shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[i]->BuildShape())[kShape]; - output.push_back(std::make_shared(input_args[i]->BuildType(), shape)); + auto abstract_tensor = lite::CreateTensorAbstract(shape, input_args[i]->BuildType()->type_id()); + MS_EXCEPTION_IF_NULL(abstract_tensor); + output.push_back(abstract_tensor); } return std::make_shared(output); } diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index 40a97eb5a47..61f8744eb1f 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -41,34 +41,34 @@ CaffeModelParser::CaffeModelParser() = default; CaffeModelParser::~CaffeModelParser() = default; -FuncGraphPtr CaffeModelParser::Parse(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) { +int CaffeModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) { STATUS status = InitOriginModel(model_file, weight_file); if (status != RET_OK) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + return status; } - func_graph_ptr_ = std::make_shared(); + res_graph_ = std::make_shared(); status = ConvertGraphInputs(); if (status != RET_OK) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + return status; } status = ConvertLayers(); if (status != RET_OK) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + return status; } status = ConvertGraphOutputs(); if (status != RET_OK) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + return status; } - func_graph_ptr_->set_attr("graph_name", MakeValue("main_graph")); - func_graph_ptr_->set_attr("fmk", MakeValue(static_cast(converter::FmkType_CAFFE))); - return func_graph_ptr_; + res_graph_->set_attr("graph_name", MakeValue("main_graph")); + res_graph_->set_attr("fmk", MakeValue(static_cast(converter::FmkType_CAFFE))); + return RET_OK; } STATUS CaffeModelParser::ConvertLayers() { @@ -134,7 +134,7 @@ STATUS CaffeModelParser::ConvertLayers() { std::vector op_inputs = {NewValueNode(std::shared_ptr(primitive_c))}; op_inputs.insert(op_inputs.end(), input_nodes.begin(), input_nodes.end()); op_inputs.insert(op_inputs.end(), const_parameters.begin(), const_parameters.end()); - auto new_cnode = func_graph_ptr_->NewCNode(op_inputs); + auto new_cnode = res_graph_->NewCNode(op_inputs); new_cnode->set_fullname_with_scope(layer.name()); // convert outputs @@ -194,14 +194,17 @@ STATUS CaffeModelParser::ConvertGraphInputs() { for (int i = 0; i < caffe_model_.layer_size(); i++) { auto layer = caffe_model_.layer(i); if (layer.type() == "Input") { - auto parameter = func_graph_ptr_->add_parameter(); + auto parameter = res_graph_->add_parameter(); std::vector shape; for (int j = 0; j < layer.input_param().shape(0).dim_size(); j++) { shape.push_back(layer.input_param().shape(0).dim(j)); } - auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); - auto abstract_tensor = std::make_shared(type_ptr, shape); - parameter->set_abstract(abstract_tensor); + auto abstract = CreateTensorAbstract(shape, kNumberTypeFloat32); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + parameter->set_abstract(abstract); parameter->set_name("graph_input-" + std::to_string(i)); nodes_.insert(std::pair(layer.top(0), parameter)); } @@ -220,10 +223,13 @@ STATUS CaffeModelParser::ConvertGraphInputs() { shape.push_back(caffe_model_.input_dim(j)); } } - auto parameter = func_graph_ptr_->add_parameter(); - auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); - auto abstract_tensor = std::make_shared(type_ptr, shape); - parameter->set_abstract(abstract_tensor); + auto parameter = res_graph_->add_parameter(); + auto abstract = CreateTensorAbstract(shape, kNumberTypeFloat32); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + parameter->set_abstract(abstract); parameter->set_name("graph_input-" + caffe_model_.input(i)); nodes_.insert(std::pair(caffe_model_.input(i), parameter)); } @@ -234,10 +240,18 @@ STATUS CaffeModelParser::ConvertGraphInputs() { for (int j = 0; j < shape.dim_size(); j++) { shape_vector.push_back(shape.dim(j)); } - auto parameter = func_graph_ptr_->add_parameter(); - auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - parameter->set_abstract(abstract_tensor); + auto parameter = res_graph_->add_parameter(); + auto tensor_info = CreateTensorInfo(nullptr, 0, shape_vector, kNumberTypeFloat32); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Create tensor info failed"; + return RET_ERROR; + } + auto abstract = tensor_info->ToAbstract(); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + parameter->set_abstract(abstract); parameter->set_name("graph_input-" + caffe_model_.input(i)); nodes_.insert(std::pair(caffe_model_.input(i), parameter)); } @@ -265,7 +279,7 @@ STATUS CaffeModelParser::ConvertGraphOutputs() { auto cnode = nodes_.find(output_node)->second; make_tuple_inputs.emplace_back(cnode); } - auto make_tuple_cnode = func_graph_ptr_->NewCNode(make_tuple_inputs); + auto make_tuple_cnode = res_graph_->NewCNode(make_tuple_inputs); make_tuple_cnode->set_fullname_with_scope("return tuple"); std::vector op_inputs; @@ -277,9 +291,9 @@ STATUS CaffeModelParser::ConvertGraphOutputs() { auto value_node = NewValueNode(return_prim_ptr); op_inputs.emplace_back(value_node); op_inputs.emplace_back(make_tuple_cnode); - auto cnode = func_graph_ptr_->NewCNode(op_inputs); + auto cnode = res_graph_->NewCNode(op_inputs); cnode->set_fullname_with_scope("Return"); - func_graph_ptr_->set_return(cnode); + res_graph_->set_return(cnode); } else { auto returnPrim = std::make_shared(); if (returnPrim == nullptr) { @@ -298,9 +312,9 @@ STATUS CaffeModelParser::ConvertGraphOutputs() { return RET_NOT_FIND_OP; } opInputs.emplace_back(cnode); - auto returnCnode = func_graph_ptr_->NewCNode(opInputs); + auto returnCnode = res_graph_->NewCNode(opInputs); returnCnode->set_fullname_with_scope("Return"); - func_graph_ptr_->set_return(returnCnode); + res_graph_->set_return(returnCnode); } return RET_OK; } @@ -333,7 +347,7 @@ STATUS CaffeModelParser::ConvertBlobs(const caffe::LayerParameter &layer, std::v ConvertShape(layer.blobs(i), &shape); // cal Weight num - auto parameter = func_graph_ptr_->add_parameter(); + auto parameter = res_graph_->add_parameter(); auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); std::vector shape_vector; (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), @@ -402,17 +416,25 @@ STATUS CaffeModelParser::ConvertBottom(const caffe::LayerParameter &layer, std:: } STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CNodePtr &cnode) { - auto type_ptr = TypeIdToType(TypeId::kNumberTypeFloat32); - std::vector shape_vector; if (layer.top_size() == 1) { - cnode->set_abstract(std::make_shared(type_ptr, shape_vector)); + auto abstract = CreateTensorAbstract({}, kNumberTypeFloat32); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + cnode->set_abstract(abstract); nodes_[layer.top(0)] = cnode; return RET_OK; } AbstractBasePtrList abstract_list; for (int i = 0; i < layer.top_size(); i++) { - abstract_list.emplace_back(std::make_shared(type_ptr, shape_vector)); + auto abstract = CreateTensorAbstract({}, kNumberTypeFloat32); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + abstract_list.emplace_back(abstract); auto tuple_get_item_prim_ptr = std::make_shared(); if (tuple_get_item_prim_ptr == nullptr) { MS_LOG(ERROR) << "new TupleGetItem failed"; @@ -421,7 +443,7 @@ STATUS CaffeModelParser::ConvertTop(const caffe::LayerParameter &layer, const CN auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); auto get_item_value = NewValueNode(MakeValue(i)); std::vector inputs{tuple_get_item_prim, cnode, get_item_value}; - CNodePtr get_item_cnode = func_graph_ptr_->NewCNode(inputs); + CNodePtr get_item_cnode = res_graph_->NewCNode(inputs); get_item_cnode->set_fullname_with_scope(layer.top(i)); nodes_[layer.top(i)] = get_item_cnode; } @@ -446,4 +468,6 @@ std::string CaffeModelParser::GetOriginLayerName(const std::string &layer_name) } return layer.name(); } + +int CaffeModelParser::PostAdjust() { return RET_OK; } } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h index 8eefd3cfab9..1faf493a5bc 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h @@ -32,8 +32,10 @@ class CaffeModelParser : public ModelParser { ~CaffeModelParser() override; - FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) override; + int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) override; + + int PostAdjust() override; private: STATUS InitOriginModel(const std::string &model_file, const std::string &weight_file); @@ -59,7 +61,6 @@ class CaffeModelParser : public ModelParser { caffe::NetParameter caffe_weight_; std::unordered_map caffe_layers_; std::unordered_map nodes_; - FuncGraphPtr func_graph_ptr_; }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index ffabac69045..a68c25d0fe1 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -45,31 +45,31 @@ static const std::unordered_map TYPE_MAP = { {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}, {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}}; -FuncGraphPtr OnnxModelParser::Parse(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) { +int OnnxModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) { NotSupportOp::GetInstance()->set_fmk_type("ONNX"); - anf_root_graph_ = std::make_shared(); + res_graph_ = std::make_shared(); auto status = InitOriginModel(model_file); if (RET_OK != status) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); MS_LOG(ERROR) << "init origin model failed."; - return nullptr; + return status; } - status = ConvertOnnxGraph(onnx_root_graph_, anf_root_graph_, &anf_nodes_map_, {}, "root_node"); + status = ConvertOnnxGraph(onnx_root_graph_, res_graph_, &anf_nodes_map_, {}, "root_node"); if (RET_OK != status) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); MS_LOG(ERROR) << "convert onnx graph failed."; - return nullptr; + return status; } - static auto root_func_manager = Manage(anf_root_graph_); + static auto root_func_manager = Manage(res_graph_); for (auto &subgraph : all_subgraphs_) { subgraph->set_manager(root_func_manager); subgraph->set_attr("fmk", MakeValue(static_cast(converter::FmkType_ONNX))); } - anf_root_graph_->set_attr("graph_name", MakeValue("main_graph")); - anf_root_graph_->set_attr("fmk", MakeValue(static_cast(converter::FmkType_ONNX))); - return anf_root_graph_; + res_graph_->set_attr("graph_name", MakeValue("main_graph")); + res_graph_->set_attr("fmk", MakeValue(static_cast(converter::FmkType_ONNX))); + return RET_OK; } STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) { @@ -88,9 +88,9 @@ STATUS OnnxModelParser::InitOriginModel(const std::string &model_file) { OnnxNodeParser::set_opset_version(onnx_model_.opset_import().Get(0).version()); onnx_root_graph_ = onnx_model_.graph(); if (OnnxNodeParser::opset_version() > 15) { - anf_root_graph_->set_attr("fmk", MakeValue(static_cast(converter::FmkType_ONNX))); + res_graph_->set_attr("fmk", MakeValue(static_cast(converter::FmkType_ONNX))); } else { - anf_root_graph_->set_attr("fmk", MakeValue(static_cast(converter::FmkType_ONNX_LOW_VERSION))); + res_graph_->set_attr("fmk", MakeValue(static_cast(converter::FmkType_ONNX_LOW_VERSION))); } return RET_OK; } @@ -170,13 +170,16 @@ STATUS OnnxModelParser::ConvertGraphInputs(const onnx::GraphProto &onnx_graph, c << static_cast(input_value.type().tensor_type().elem_type()); return RET_ERROR; } - auto type_ptr = TypeIdToType(data_type); std::vector shape_vector; auto onnx_shape = input_value.type().tensor_type().shape().dim(); std::transform(onnx_shape.begin(), onnx_shape.end(), std::back_inserter(shape_vector), [](const onnx::TensorShapeProto_Dimension &val) { return static_cast(val.dim_value()); }); std::replace(shape_vector.begin(), shape_vector.end(), 0, -1); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); + auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } parameter->set_abstract(abstract_tensor); parameter->set_name(input_value.name()); anf_nodes_map->emplace(input_value.name(), parameter); @@ -490,17 +493,23 @@ STATUS OnnxModelParser::BuildOpOutputs(const onnx::NodeProto &onnx_node, const F return RET_NULL_PTR; } if (onnx_node.output_size() == 1) { - auto type_ptr = TypeIdToType(kNumberTypeFloat32); - std::vector shape_vector; - cnode->set_abstract(std::make_shared(type_ptr, shape_vector)); + auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + cnode->set_abstract(abstract_tensor); anf_nodes_map->emplace(onnx_node.output(0), cnode); } else { AbstractBasePtrList abstract_list; int op_idx = 0; for (const auto &output_name : onnx_node.output()) { - std::vector shape_vector; - auto type_ptr = TypeIdToType(kNumberTypeFloat32); - abstract_list.emplace_back(std::make_shared(type_ptr, shape_vector)); + auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + abstract_list.emplace_back(abstract_tensor); auto tuple_get_item_prim_ptr = std::make_shared(); if (tuple_get_item_prim_ptr == nullptr) { MS_LOG(ERROR) << "new TupleGetItem failed"; @@ -687,7 +696,11 @@ ParameterPtr CreateConstParamter(const FuncGraphPtr &anf_graph, int val) { return nullptr; } auto const_node = anf_graph->add_parameter(); - auto const_abstract = std::make_shared(kInt32, std::vector()); + auto const_abstract = CreateTensorAbstract({}, kNumberTypeInt32); + if (const_abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return nullptr; + } const_node->set_abstract(const_abstract); int *tensor_data = new (std::nothrow) int[1]; if (tensor_data == nullptr) { @@ -834,9 +847,16 @@ STATUS OnnxModelParser::AddTensorArrayEdge(const FuncGraphPtr &anf_graph, std::v for (int i = 0; i < act_output_num; i++) { // tensor_array need as root while input auto while_tensor_array_input = anf_root_graph->add_parameter(); - std::vector shape_vector; - auto abstract_tensor = std::make_shared(kTensorType, shape_vector); - auto tensor_info = std::make_shared(kObjectTypeTensorType, shape_vector); + auto tensor_info = CreateTensorInfo(nullptr, 0, {}, kObjectTypeTensorType); + if (tensor_info == nullptr) { + MS_LOG(ERROR) << "Create tensor info failed"; + return RET_ERROR; + } + auto abstract_tensor = tensor_info->ToAbstract(); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } while_tensor_array_input->set_abstract(abstract_tensor); while_tensor_array_input->set_default_param(tensor_info); while_tensor_array_input->set_name(loop_node_name + "_scan_outputs_tensorarray"); @@ -975,7 +995,11 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf auto input_paramter = cond_graph->add_parameter(); input_paramter->set_name(cond_graph_name + "_input_" + std::to_string(i) + "_parameter"); auto root_while_inputs = root_while_node->cast()->inputs(); - auto input_abstract = std::make_shared(kInt32, std::vector()); + auto input_abstract = CreateTensorAbstract({}, kNumberTypeInt32); + if (input_abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } input_paramter->set_abstract(input_abstract); if (i == 0) { auto zero_parameter = CreateConstParamter(cond_graph, 0); @@ -987,7 +1011,11 @@ STATUS OnnxModelParser::BuildCondGraph(const FuncGraphPtr &cond_graph, const Anf MS_LOG(ERROR) << "new cnode error"; return RET_ERROR; } - auto less_abstract = std::make_shared(kBool, std::vector()); + auto less_abstract = CreateTensorAbstract({}, kNumberTypeBool); + if (less_abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } less_cnode->set_abstract(less_abstract); less_cnode->set_fullname_with_scope(cond_graph_name + "_less_cnode"); } @@ -1020,12 +1048,11 @@ STATUS OnnxModelParser::BuildParameterNodeForQuantParam(const void *data, const MS_LOG(ERROR) << "quant param type don't support."; return RET_NOT_SUPPORT; } - std::vector shape_vector; - auto parameter_node = anf_root_graph_->add_parameter(); - auto abstract_tensor = std::make_shared(TypeIdToType(type), shape_vector); + auto parameter_node = res_graph_->add_parameter(); + auto abstract_tensor = CreateTensorAbstract({}, type); if (abstract_tensor == nullptr) { - MS_LOG(ERROR) << "new abstract_tensor failed"; - return RET_MEMORY_FAILED; + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; } parameter_node->set_abstract(abstract_tensor); parameter_node->set_name(name); @@ -1051,9 +1078,12 @@ STATUS OnnxModelParser::BuildParameterNode(const ParameterPtr ¶meter_node, c MS_LOG(ERROR) << "not support onnx data type " << static_cast(tensor.data_type()); return RET_ERROR; } - auto type_ptr = TypeIdToType(data_type); std::vector shape_vector(tensor.dims().begin(), tensor.dims().end()); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); + auto abstract_tensor = CreateTensorAbstract(shape_vector, data_type); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } parameter_node->set_abstract(abstract_tensor); parameter_node->set_name(tensor.name()); @@ -1142,5 +1172,7 @@ TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type } return iter->second; } + +int OnnxModelParser::PostAdjust() { return 0; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index 6bc638a1cf0..463a6dfefad 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -40,14 +40,17 @@ class OnnxModelParser : public ModelParser { ~OnnxModelParser() override = default; - FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) override; + int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) override; + + int PostAdjust() override; + static TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); static STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_const_tensor, const tensor::TensorPtr ¶m_value_lite); + STATUS InitOriginModel(const std::string &model_file); private: - STATUS InitOriginModel(const std::string &model_file); STATUS ConvertNodes(const onnx::GraphProto &onnx_graph, const FuncGraphPtr &func_graph_ptr, std::unordered_map *anf_nodes_map, std::vector *graph_inputs, const std::string &root_node_name); @@ -94,7 +97,6 @@ class OnnxModelParser : public ModelParser { std::unordered_map anf_nodes_map_; std::unordered_map *> control_nodes_map_; std::unordered_map child_root_map_; // for nest control flow node - FuncGraphPtr anf_root_graph_ = nullptr; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc index 993fb069199..f9c856d0836 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.cc @@ -417,18 +417,17 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa type = TensorFlowUtils::GetTFDataType(attr_value.type()); } - std::vector shape; + std::vector shape; if (TensorFlowUtils::FindAttrValue(node, "shape", &attr_value)) { auto &shape_attr = attr_value.shape(); for (int i = 0; i < shape_attr.dim_size(); ++i) { shape.push_back(shape_attr.dim(i).size()); } } - std::vector shape_vector(shape.begin(), shape.end()); if (TensorFlowUtils::FindAttrValue(node, "value", &attr_value)) { MS_LOG(INFO) << "Found value attr, means it has default value"; - auto status = ConvertConstTensor(node, attr_value, type, parameter, &shape_vector); + auto status = ConvertConstTensor(node, attr_value, type, parameter, &shape); if (status != RET_OK) { MS_LOG(ERROR) << "convert const tensor failed."; return status; @@ -437,10 +436,10 @@ STATUS TFModelParser::ConvertParameter(const tensorflow::NodeDef &node, const Pa graph_input_names_.emplace_back(node.name()); // only root graph need set graph input names } - auto type_ptr = TypeIdToType(type == kNumberTypeInt64 ? kNumberTypeInt32 : type); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); + type = (type == kNumberTypeInt64) ? kNumberTypeInt32 : type; + auto abstract_tensor = CreateTensorAbstract(shape, type); if (abstract_tensor == nullptr) { - MS_LOG(ERROR) << "abstract_tensor is nullptr"; + MS_LOG(ERROR) << "Create tensor abstarct failed"; return RET_ERROR; } parameter->set_name(node.name()); @@ -473,51 +472,51 @@ STATUS TFModelParser::ConvertGraphInputsAndConsts( } return RET_OK; } -FuncGraphPtr paserTfFuction() { return nullptr; } -FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::string &weightFile, - const QuantType &quantType) { + +int TFModelParser::ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile, + const QuantType &quantType) { NotSupportOp::GetInstance()->set_fmk_type("TF"); auto status = ValidateFileStr(modelFile, ".pb"); if (status != RET_OK) { MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.pb"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + return status; } tf_root_graph_ = std::make_unique(); if (tf_root_graph_ == nullptr) { MS_LOG(ERROR) << "tf_root_graph_ is nullptr"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); - return nullptr; + return status; } status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), tf_root_graph_.get()); if (status != RET_OK) { MS_LOG(ERROR) << "Open modelFile for TF converter failed!"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + return status; } - anf_root_graph_ = std::make_shared(); - if (anf_root_graph_ == nullptr) { + res_graph_ = std::make_shared(); + if (res_graph_ == nullptr) { MS_LOG(ERROR) << "funGraphPtr is nullptr"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR); - return nullptr; + return status; } - anf_root_graph_->set_attr("graph_name", MakeValue("main_graph")); - anf_root_graph_->set_attr("fmk", MakeValue(static_cast(converter::FmkType_TF))); + res_graph_->set_attr("graph_name", MakeValue("main_graph")); + res_graph_->set_attr("fmk", MakeValue(static_cast(converter::FmkType_TF))); for (int i = 0; i < tf_root_graph_->node_size(); i++) { auto &node_def = tf_root_graph_->node(i); tf_root_graph_nodes_[node_def.name()] = &node_def; } - status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); + status = ConvertGraphInputsAndConsts(tf_root_graph_nodes_, res_graph_, &anf_root_node_map_); if (status != RET_OK) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + return status; } bool success_flag = true; for (int i = 0; i < tf_root_graph_->node_size(); i++) { auto &node_def = tf_root_graph_->node(i); - status = ConvertOps(node_def, tf_root_graph_nodes_, anf_root_graph_, &anf_root_node_map_); + status = ConvertOps(node_def, tf_root_graph_nodes_, res_graph_, &anf_root_node_map_); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); if (status != RET_OK) { success_flag = false; @@ -525,7 +524,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin } if (!success_flag) { MS_LOG(ERROR) << "Convert ops failed."; - return nullptr; + return RET_ERROR; } if (!nodes_with_null_input_.empty()) { @@ -533,7 +532,7 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin if (status != RET_OK) { MS_LOG(ERROR) << "Connect null inputs failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + return status; } } @@ -541,17 +540,17 @@ FuncGraphPtr TFModelParser::Parse(const std::string &modelFile, const std::strin if (status != RET_OK) { MS_LOG(ERROR) << "Convert graph outputs failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + return status; } status = ConvertSubgraph(); if (status != RET_OK) { MS_LOG(ERROR) << "Convert subgraph failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + return status; } - return anf_root_graph_; + return RET_OK; } STATUS TFModelParser::ConvertSubgraphInputs(std::map *tf_sub_node_map, @@ -745,7 +744,7 @@ STATUS TFModelParser::ControlFlowNodePostProcess(const std::mapinputs(); inputs.insert(inputs.begin() + 1, {first_value_node, second_value_node}); - auto new_node = anf_root_graph_->NewCNode(inputs); // must create new node, otherwise node_users won't update + auto new_node = res_graph_->NewCNode(inputs); // must create new node, otherwise node_users won't update if (new_node == nullptr) { MS_LOG(ERROR) << "new node failed"; return RET_ERROR; @@ -811,43 +810,46 @@ STATUS TFModelParser::ConvertOutputTensor(const tensorflow::NodeDef &op, const C if (output_size == 0) { return RET_OK; } else if (output_size == 1) { - auto type = kFloat32; - std::vector shape_vector; + auto type = kNumberTypeFloat32; if (IsTensorListOp(anf_node)) { - type = TypeIdToType(kObjectTypeTensorType); + type = kObjectTypeTensorType; } - auto abstract = std::make_shared(type, shape_vector); - if (abstract == nullptr) { - MS_LOG(ERROR) << "create AbstractTensor failed"; + auto abstract_tensor = CreateTensorAbstract({}, type); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; return RET_ERROR; } - anf_node->set_abstract(abstract); + anf_node->set_abstract(abstract_tensor); anf_node_map->insert(std::pair(op.name(), anf_node)); } else { - AbstractBasePtrList abstractList; + AbstractBasePtrList abstract_list; for (int output_idx = 0; output_idx < output_size; output_idx++) { - std::vector shape_vector; - abstractList.emplace_back(std::make_shared(kFloat32, shape_vector)); - auto tupleGetItemPrimPtr = std::make_shared(); - if (tupleGetItemPrimPtr == nullptr) { + auto abstract_tensor = CreateTensorAbstract({}, kNumberTypeFloat32); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + abstract_list.emplace_back(abstract_tensor); + auto tuple_get_item_prim_ptr = std::make_shared(); + if (tuple_get_item_prim_ptr == nullptr) { MS_LOG(ERROR) << "new TupleGetItem failed"; return RET_NULL_PTR; } - auto tupleGetItemPrim = NewValueNode(tupleGetItemPrimPtr); - auto getItemValue = NewValueNode(MakeValue(output_idx)); - std::vector inputs{tupleGetItemPrim, anf_node, getItemValue}; - CNodePtr getItemCNode = anf_graph->NewCNode(inputs); + auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); + auto get_item_value = NewValueNode(MakeValue(output_idx)); + std::vector inputs{tuple_get_item_prim, anf_node, get_item_value}; + CNodePtr get_item_cnode = anf_graph->NewCNode(inputs); std::string output_item_name = anf_node->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); - auto abstract = std::make_shared(kFloat32, shape_vector); - if (abstract == nullptr) { - MS_LOG(ERROR) << "create AbstractTensor failed"; + auto get_item_abstract = CreateTensorAbstract({}, kNumberTypeFloat32); + if (get_item_abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; return RET_ERROR; } - getItemCNode->set_abstract(abstract); - getItemCNode->set_fullname_with_scope(output_item_name); - anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), getItemCNode)); + get_item_cnode->set_abstract(get_item_abstract); + get_item_cnode->set_fullname_with_scope(output_item_name); + anf_node_map->insert(std::pair(op.name() + ":" + std::to_string(output_idx), get_item_cnode)); } - anf_node->set_abstract(std::make_shared(abstractList)); + anf_node->set_abstract(std::make_shared(abstract_list)); } return RET_OK; } @@ -1003,7 +1005,7 @@ STATUS TFModelParser::ConvertRootGraphOutputs() { graph_output_names_.push_back(anf_node->fullname_with_scope()); } } - auto status = MakeAnfGraphOutputs(&output_nodes, anf_root_graph_); + auto status = MakeAnfGraphOutputs(&output_nodes, res_graph_); if (status != RET_OK) { MS_LOG(ERROR) << "make anf graph outputs node error"; return status; @@ -1051,5 +1053,7 @@ STATUS TFModelParser::MakeAnfGraphOutputs(std::vector *output_nodes, } return RET_OK; } + +int TFModelParser::PostAdjust() { return 0; } } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h index 512407a2e5d..08509a82dd8 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tf/tf_model_parser.h @@ -36,9 +36,11 @@ namespace lite { class TFModelParser : public ModelParser { public: TFModelParser() = default; - ~TFModelParser() = default; + ~TFModelParser() override = default; - FuncGraphPtr Parse(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); + int ParseToFuncGraph(const std::string &modelFile, const std::string &weightFile, const QuantType &quantType); + + int PostAdjust() override; private: static STATUS ConvertConstVariant(const tensorflow::TensorProto &tensor_proto, tensor::TensorPtr *tensor_info); @@ -84,7 +86,6 @@ class TFModelParser : public ModelParser { STATUS ConnectNullInput(); - FuncGraphPtr anf_root_graph_; std::unique_ptr tf_root_graph_; // tf root graph def std::map tf_root_graph_nodes_; // tf root graph node map std::unordered_map anf_root_node_map_; diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index b456bfcbd6c..e0878415b7a 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -43,46 +43,46 @@ std::unique_ptr TfliteModelParser::ReadTfliteModel(const char *m return tflite::UnPackModel(tflite_model_buf_); } -FuncGraphPtr TfliteModelParser::Parse(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) { +int TfliteModelParser::ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) { // load graph tflite_model_ = ReadTfliteModel(model_file.c_str()); if (tflite_model_ == nullptr) { MS_LOG(ERROR) << "read tflite model failed"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); - return nullptr; + return RET_GRAPH_FILE_ERR; } if (tflite_model_->subgraphs.size() != 1) { MS_LOG(ERROR) << "read tflite model subgraphs failed"; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_GRAPH_FILE_ERR); - return nullptr; + return RET_GRAPH_FILE_ERR; } - func_graph_ = std::make_shared(); - func_graph_->set_attr("fmk", MakeValue(static_cast(converter::FmkType_TFLITE))); + res_graph_ = std::make_shared(); + res_graph_->set_attr("fmk", MakeValue(static_cast(converter::FmkType_TFLITE))); auto status = ConvertGraphInputs(); if (status != RET_OK) { MS_LOG(ERROR) << "Convert graph inputs failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + return status; } status = ConvertOps(); if (status != RET_OK) { MS_LOG(ERROR) << "Convert ops failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + return status; } status = ConvertGraphOutputs(); if (status != RET_OK) { MS_LOG(ERROR) << "Convert graph outputs failed."; ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); - return nullptr; + return status; } - func_graph_->set_attr("graph_name", MakeValue("main_graph")); - return func_graph_; + res_graph_->set_attr("graph_name", MakeValue("main_graph")); + return RET_OK; } std::string GetTensorName(size_t index, const tflite::BuiltinOperator &op_type, const std::string &op_name) { @@ -158,7 +158,7 @@ STATUS TfliteModelParser::ConvertOps() { } else { tensor_name = GetTensorName(i, tflite_op_type, op_name); } - auto parameter = func_graph_->add_parameter(); + auto parameter = res_graph_->add_parameter(); status = ConvertConstTensor(input_tensor.get(), parameter, tensor_name); if (status != RET_OK) { MS_LOG(ERROR) << "convert " << op_name << " node: " << input_idx << " const node failed."; @@ -168,7 +168,7 @@ STATUS TfliteModelParser::ConvertOps() { op_inputs.emplace_back(parameter); nodes_.insert(std::pair(input_idx, parameter)); } - auto new_cnode = func_graph_->NewCNode(op_inputs); + auto new_cnode = res_graph_->NewCNode(op_inputs); new_cnode->set_fullname_with_scope(op_name); // parse outputs @@ -284,13 +284,16 @@ STATUS TfliteModelParser::ConvertGraphInputs() { if (tflite_graph_input < 0) { tflite_graph_input = tflite_graph_input + tflite_subgraph->tensors.size(); } - auto parameter = func_graph_->add_parameter(); + auto parameter = res_graph_->add_parameter(); const auto &tensor = tflite_subgraph->tensors.at(tflite_graph_input); std::vector shape_vector; (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), [](const int32_t &value) { return static_cast(value); }); - auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); + auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type)); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } parameter->set_abstract(abstract_tensor); parameter->set_name("graph_input-" + std::to_string(tflite_graph_input)); nodes_.insert(std::pair(tflite_graph_input, parameter)); @@ -318,7 +321,7 @@ STATUS TfliteModelParser::ConvertGraphOutputs() { } make_tuple_inputs.emplace_back(cnode); } - auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs); + auto make_tuple_cnode = res_graph_->NewCNode(make_tuple_inputs); make_tuple_cnode->set_fullname_with_scope("return tuple"); std::vector op_inputs; @@ -330,9 +333,9 @@ STATUS TfliteModelParser::ConvertGraphOutputs() { auto value_node = NewValueNode(return_prim_ptr); op_inputs.emplace_back(value_node); op_inputs.emplace_back(make_tuple_cnode); - auto cnode = func_graph_->NewCNode(op_inputs); + auto cnode = res_graph_->NewCNode(op_inputs); cnode->set_fullname_with_scope("Return"); - func_graph_->set_return(cnode); + res_graph_->set_return(cnode); } else { auto returnPrim = std::make_shared(); if (returnPrim == nullptr) { @@ -350,9 +353,9 @@ STATUS TfliteModelParser::ConvertGraphOutputs() { return RET_NOT_FIND_OP; } op_inputs.emplace_back(cnode); - auto returnCnode = func_graph_->NewCNode(op_inputs); + auto returnCnode = res_graph_->NewCNode(op_inputs); returnCnode->set_fullname_with_scope("Return"); - func_graph_->set_return(returnCnode); + res_graph_->set_return(returnCnode); } return RET_OK; } @@ -436,8 +439,12 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const std::vector shape_vector; (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), [](const int32_t &value) { return static_cast(value); }); - auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); - dst_cnode->set_abstract(std::make_shared(type_ptr, shape_vector)); + auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type)); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + dst_cnode->set_abstract(abstract_tensor); nodes_.insert(std::pair(op->outputs.front(), dst_cnode)); } else { AbstractBasePtrList abstract_list; @@ -450,8 +457,12 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const std::vector shape_vector; (void)std::transform(tensor->shape.begin(), tensor->shape.end(), std::back_inserter(shape_vector), [](const int32_t &value) { return static_cast(value); }); - auto type_ptr = TypeIdToType(GetTfliteDataType(tensor->type)); - abstract_list.emplace_back(std::make_shared(type_ptr, shape_vector)); + auto abstract_tensor = CreateTensorAbstract(shape_vector, GetTfliteDataType(tensor->type)); + if (abstract_tensor == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + abstract_list.emplace_back(abstract_tensor); auto tuple_get_item_prim_ptr = std::make_shared(); if (tuple_get_item_prim_ptr == nullptr) { MS_LOG(ERROR) << "new TupleGetItem failed"; @@ -460,7 +471,7 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const auto tuple_get_item_prim = NewValueNode(tuple_get_item_prim_ptr); auto get_item_value = NewValueNode(MakeValue(op_idx)); std::vector inputs{tuple_get_item_prim, dst_cnode, get_item_value}; - CNodePtr get_item_cnode = func_graph_->NewCNode(inputs); + CNodePtr get_item_cnode = res_graph_->NewCNode(inputs); get_item_cnode->set_fullname_with_scope(dst_cnode->fullname_with_scope() + "_getitem_" + std::to_string(op_idx)); nodes_.insert(std::pair(output_idx, get_item_cnode)); op_idx++; @@ -469,4 +480,6 @@ STATUS TfliteModelParser::ConvertOutputTensor(const tflite::OperatorT *op, const } return RET_OK; } + +int TfliteModelParser::PostAdjust() { return 0; } } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h index 99e1c0b635a..e88a830cd72 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.h @@ -32,13 +32,14 @@ class TfliteModelParser : public ModelParser { ~TfliteModelParser() override = default; - FuncGraphPtr Parse(const std::string &model_file, const std::string &weight_file, - const QuantType &quant_type) override; + int ParseToFuncGraph(const std::string &model_file, const std::string &weight_file, + const QuantType &quant_type) override; + + int PostAdjust() override; private: std::unordered_map nodes_; std::unique_ptr tflite_model_; - FuncGraphPtr func_graph_; char *tflite_model_buf_ = nullptr; std::unique_ptr ReadTfliteModel(const char *model_path); STATUS ConvertConstTensor(const tflite::TensorT *tensor, const ParameterPtr ¶meter, diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index 4ccc20330c1..40047c7fb80 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -399,6 +399,24 @@ int CheckIfCNodeIsNull(const CNodePtr &node) { return lite::RET_OK; } +int CheckIfParameterIsNull(const ParameterPtr &node) { + if (node == nullptr) { + MS_LOG(ERROR) << "The Parameter is null."; + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); + return lite::RET_NULL_PTR; + } + return lite::RET_OK; +} + +int CheckIfValueNodeIsNull(const ValueNodePtr &node) { + if (node == nullptr) { + MS_LOG(ERROR) << "The ValueNode is null."; + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); + return lite::RET_NULL_PTR; + } + return lite::RET_OK; +} + int CheckIfVarIsNull(const VarPtr &var) { if (var == nullptr) { MS_LOG(ERROR) << "The Var is null."; diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index aece7064d30..3bb51fb022e 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -57,6 +57,10 @@ int CheckIfAnfNodeIsNull(const AnfNodePtr &node); int CheckIfCNodeIsNull(const CNodePtr &node); +int CheckIfParameterIsNull(const ParameterPtr &node); + +int CheckIfValueNodeIsNull(const ValueNodePtr &node); + int CheckIfVarIsNull(const VarPtr &var); int CheckInputSize(const CNodePtr &node, int size); diff --git a/mindspore/lite/tools/optimizer/fusion/mul_add_fusion.cc b/mindspore/lite/tools/optimizer/fusion/mul_add_fusion.cc new file mode 100644 index 00000000000..9378d5cfdc3 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/mul_add_fusion.cc @@ -0,0 +1,294 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/optimizer/fusion/mul_add_fusion.h" +#include +#include "ops/fusion/mul_fusion.h" +#include "ops/fusion/add_fusion.h" +#include "ops/fusion/scale_fusion.h" +#include "ops/op_utils.h" +#include "tools/optimizer/common/gllo_utils.h" + +namespace mindspore::opt { +namespace { +constexpr size_t kMulInputsLength = 3; +constexpr size_t kAddInputsLength = 3; +} // namespace + +const BaseRef MulAddFusion::DefinePattern() const { + auto mul_var = std::make_shared(IsSpecifiedNode<&prim::kPrimMulFusion>); + auto add_var = std::make_shared(IsSpecifiedNode<&prim::kPrimAddFusion>); + return VectorRef({add_var, mul_var}); +} + +bool MulAddFusion::ScaleInputShapeValid() const { + MS_ASSERT(scale_tensor_ != nullptr); + MS_ASSERT(bias_tensor_ != nullptr); + auto scale_shape = scale_tensor_->shape_c(); + auto offset_shape = bias_tensor_->shape_c(); + if (mul_input_shape_.size() < scale_shape.size() || scale_shape.size() == 0) { + return false; + } + size_t rank_diff = mul_input_shape_.size() - scale_shape.size(); + for (size_t i = 0; i < scale_shape.size(); ++i) { + if (mul_input_shape_[i + rank_diff] != scale_shape[i]) { + return false; + } + } + if (scale_shape != offset_shape) { + return false; + } + return true; +} + +bool MulAddFusion::CheckMulNode(const FuncGraphPtr &func_graph) const { + MS_ASSERT(func_graph != nullptr); + if (mul_anode_ == nullptr) { + return false; + } + if (IsMultiOutputTensors(func_graph, mul_anode_)) { + MS_LOG(DEBUG) << "Mul op has multi-output"; + return false; + } + auto mul_node = mul_anode_->cast(); + if (!CheckPrimitiveType(mul_node, prim::kPrimMulFusion)) { + MS_LOG(DEBUG) << "Mul add fusion pass match only mul or add"; + return false; + } + auto mul_primitive = GetValueNode>(mul_node->input(0)); + MS_ASSERT(mul_primitive != nullptr); + auto mul_act_type = mul_primitive->get_activation_type(); + if (mul_act_type != ActivationType::NO_ACTIVATION) { + MS_LOG(DEBUG) << "Only support mul node with no activation"; + return false; + } + if (CheckIfCNodeIsNull(mul_node) != lite::RET_OK || CheckInputSize(mul_node, kMulInputsLength) != lite::RET_OK) { + MS_LOG(DEBUG) << "Mul op is null or has error input size"; + return false; + } + // find mul's const input and mul input + AnfNodePtr mul_pre_input_node = nullptr; + AnfNodePtr mul_pre_const_node = nullptr; + auto mul_pre_node_1 = mul_node->input(1); + if (CheckIfAnfNodeIsNull(mul_pre_node_1) != lite::RET_OK) { + MS_LOG(DEBUG) << "Pre-node of mul op is nullptr"; + return false; + } + auto mul_pre_node_2 = mul_node->input(2); + if (CheckIfAnfNodeIsNull(mul_pre_node_2) != lite::RET_OK) { + MS_LOG(DEBUG) << "Pre-node of mul op is nullptr"; + return false; + } + if (utils::isa(mul_pre_node_1) && !utils::isa(mul_pre_node_2)) { + mul_pre_input_node = mul_pre_node_1; + mul_pre_const_node = mul_pre_node_2; + } else if (!utils::isa(mul_pre_node_1) && utils::isa(mul_pre_node_2)) { + mul_pre_input_node = mul_pre_node_1; + mul_pre_const_node = mul_pre_node_2; + } else { + MS_LOG(DEBUG) << "Mul op should has a cnode input and a const input"; + return false; + } + // check mul's const input + tensor::TensorPtr mul_tensor = nullptr; + if (utils::isa(mul_pre_const_node)) { + auto mul_bias_node = mul_pre_const_node->cast(); + MS_ASSERT(mul_bias_node != nullptr); + if (!mul_bias_node->has_default()) { + MS_LOG(DEBUG) << "Const input of mul op should has data"; + return false; + } + mul_tensor = mul_bias_node->default_param()->cast(); + } else if (utils::isa(mul_pre_const_node)) { + auto mul_bias_node = mul_pre_const_node->cast(); + MS_ASSERT(mul_bias_node != nullptr); + if (mul_bias_node->value() == nullptr) { + MS_LOG(DEBUG) << "Const input of mul op should has data"; + return false; + } + mul_tensor = mul_bias_node->value()->cast(); + } else { + MS_ASSERT(false); + } + if (mul_tensor == nullptr) { + MS_LOG(DEBUG) << "Const input of add op should has data"; + return false; + } + mul_input_anode_ = mul_pre_input_node; + mul_const_anode_ = mul_pre_const_node; + scale_tensor_ = mul_tensor; + return true; +} + +bool MulAddFusion::CheckAddNode() const { + if (add_anode_ == nullptr) { + return false; + } + auto add_cnode = add_anode_->cast(); + if (CheckIfCNodeIsNull(add_cnode) != lite::RET_OK || CheckInputSize(add_cnode, kAddInputsLength) != lite::RET_OK) { + MS_LOG(DEBUG) << "Add op is null or has error input size"; + return false; + } + if (!CheckPrimitiveType(add_cnode, prim::kPrimAddFusion)) { + MS_LOG(DEBUG) << "Mul add fusion pass match only mul or add"; + return false; + } + auto add_primitive = GetValueNode>(add_cnode->input(0)); + MS_ASSERT(add_primitive != nullptr); + auto add_act_type = add_primitive->get_activation_type(); + if (add_act_type != ActivationType::RELU && add_act_type != ActivationType::RELU6 && + add_act_type != ActivationType::NO_ACTIVATION) { + MS_LOG(DEBUG) << "Only support add node with relu or relu6 or no activation"; + return false; + } + scale_act_type_ = add_act_type; + // find add's const input and mul input + AnfNodePtr add_pre_input_node = nullptr; + AnfNodePtr add_pre_const_node = nullptr; + auto add_pre_node_1 = add_cnode->input(1); + if (CheckIfAnfNodeIsNull(add_pre_node_1) != lite::RET_OK) { + MS_LOG(DEBUG) << "Pre-node of add op is nullptr"; + return false; + } + auto add_pre_node_2 = add_cnode->input(2); + if (CheckIfAnfNodeIsNull(add_pre_node_2) != lite::RET_OK) { + MS_LOG(DEBUG) << "Pre-node of add op is nullptr"; + return false; + } + if (utils::isa(add_pre_node_1) && !utils::isa(add_pre_node_2)) { + add_pre_input_node = add_pre_node_1; + add_pre_const_node = add_pre_node_2; + } else if (!utils::isa(add_pre_node_1) && utils::isa(add_pre_node_2)) { + add_pre_input_node = add_pre_node_2; + add_pre_const_node = add_pre_node_1; + } else { + MS_LOG(DEBUG) << "Add op should has a cnode input and a const input"; + return false; + } + // check add's const input + tensor::TensorPtr add_tensor = nullptr; + if (utils::isa(add_pre_const_node)) { + auto add_bias_node = add_pre_const_node->cast(); + MS_ASSERT(add_bias_node != nullptr); + if (!add_bias_node->has_default()) { + MS_LOG(DEBUG) << "Const input of add op should has data"; + return false; + } + add_tensor = add_bias_node->default_param()->cast(); + } else if (utils::isa(add_pre_const_node)) { + auto add_bias_node = add_pre_const_node->cast(); + MS_ASSERT(add_bias_node != nullptr); + if (add_bias_node->value() == nullptr) { + MS_LOG(DEBUG) << "Const input of add op should has data"; + return false; + } + add_tensor = add_bias_node->value()->cast(); + } else { + MS_ASSERT(false); + } + if (add_tensor == nullptr) { + MS_LOG(DEBUG) << "Const input of add op should has data"; + return false; + } + mul_anode_ = add_pre_input_node; + add_const_anode_ = add_pre_const_node; + bias_tensor_ = add_tensor; + return true; +} + +bool MulAddFusion::GetMulInputShape() const { + MS_ASSERT(mul_input_anode_ != nullptr); + ShapeVector mul_input_shape; + AbstractBasePtr mul_input_abstract = nullptr; + if (utils::isa(mul_input_anode_)) { + auto mul_input_node = mul_input_anode_->cast(); + MS_ASSERT(mul_bias_node != nullptr); + mul_input_abstract = mul_input_node->abstract(); + } else if (utils::isa(mul_input_anode_)) { + auto mul_input_node = mul_input_anode_->cast(); + MS_ASSERT(mul_input_node != nullptr); + mul_input_abstract = mul_input_node->abstract(); + } else if (utils::isa(mul_input_anode_)) { + auto mul_input_node = mul_input_anode_->cast(); + MS_ASSERT(mul_input_node != nullptr); + mul_input_abstract = mul_input_node->abstract(); + } else { + MS_ASSERT(false); + } + if (mul_input_abstract == nullptr) { + MS_LOG(DEBUG) << "Mul input node has no abstract"; + return false; + } + if (!utils::isa(mul_input_abstract)) { + MS_LOG(DEBUG) << "Abstract of mul input node should be AbstractTensor"; + return false; + } + auto abstract_tensor = utils::cast(mul_input_abstract); + MS_ASSERT(abstract_tensor != nullptr); + MS_ASSERT(abstract_tensor->BuildShape() != nullptr); + if (!utils::isa(abstract_tensor->BuildShape())) { + MS_LOG(DEBUG) << "BuildShape of abstract of mul input node should be ShapePtr"; + return false; + } + mul_input_shape_ = utils::cast(abstract_tensor->BuildShape())->shape(); + return true; +} + +const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_ASSERT(func_graph != nullptr); + MS_ASSERT(node != nullptr); + if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) { + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); + return nullptr; + } + add_anode_ = node; + if (!CheckAddNode()) { + MS_LOG(DEBUG) << "Add op is not suit for mul-add-fusion: " << node->fullname_with_scope(); + return nullptr; + } + MS_ASSERT(mul_anode_ != nullptr); + MS_ASSERT(bias_tensor_ != nullptr); + MS_ASSERT(add_const_anode_ != nullptr); + if (!CheckMulNode(func_graph)) { + MS_LOG(DEBUG) << "Mul op is not suit for mul-add-fusion: " << mul_anode_->fullname_with_scope(); + return nullptr; + } + MS_ASSERT(mul_input_anode_ != nullptr); + MS_ASSERT(scale_tensor_ != nullptr); + MS_ASSERT(mul_const_anode_ != nullptr); + if (!GetMulInputShape()) { + MS_LOG(DEBUG) << "Get input shape of mul op failed"; + return nullptr; + } + // scale requires scale shape tail sub of input shape, scale shape same as bias shape + if (!ScaleInputShapeValid()) { + MS_LOG(DEBUG) << "Check input shape, scale shape and bias shape failed"; + return nullptr; + } + // create scale primitive + auto scale_primitive = new (std::nothrow) mindspore::ops::ScaleFusion(); + if (scale_primitive == nullptr) { + MS_LOG(ERROR) << "new scale primitive failed"; + return nullptr; + } + scale_primitive->set_activation_type(scale_act_type_); + scale_primitive->set_axis(0 - bias_tensor_->shape_c().size()); + // create scale op + auto scale_node = func_graph->NewCNode(std::shared_ptr(scale_primitive), + {mul_input_anode_, mul_const_anode_, add_const_anode_}); + return scale_node; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/fusion/mul_add_fusion.h b/mindspore/lite/tools/optimizer/fusion/mul_add_fusion.h new file mode 100644 index 00000000000..de1d9069db2 --- /dev/null +++ b/mindspore/lite/tools/optimizer/fusion/mul_add_fusion.h @@ -0,0 +1,53 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_MUL_ADD_FUSION_H_ +#define MINDSPORE_LITE_SRC_PASS_FUSION_MUL_ADD_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace opt { +class MulAddFusion : public PatternProcessPass { + public: + explicit MulAddFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion") + : PatternProcessPass(name, multigraph) {} + ~MulAddFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + bool CheckMulNode(const FuncGraphPtr &func_graph) const; + bool CheckAddNode() const; + bool GetMulInputShape() const; + bool ScaleInputShapeValid() const; + + private: + mutable AnfNodePtr mul_anode_ = nullptr; + mutable AnfNodePtr mul_input_anode_ = nullptr; + mutable AnfNodePtr mul_const_anode_ = nullptr; + mutable ShapeVector mul_input_shape_; + mutable AnfNodePtr add_anode_ = nullptr; + mutable AnfNodePtr add_const_anode_ = nullptr; + mutable tensor::TensorPtr scale_tensor_ = nullptr; + mutable tensor::TensorPtr bias_tensor_ = nullptr; + mutable ActivationType scale_act_type_ = ActivationType::NO_ACTIVATION; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc index 7907ad3890b..f2740b4bc89 100644 --- a/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/tf_bidirection_gru_fusion.cc @@ -256,11 +256,12 @@ ParameterPtr TfBidirectionGruFusion::AddDefaultParameter(const FuncGraphPtr &fun auto parameter = func_graph->add_parameter(); parameter->set_name(name); std::vector shape_vector(shape.begin(), shape.end()); - auto abstract_tensor = std::make_shared(TypeIdToType(type), shape_vector); - if (abstract_tensor == nullptr) { + auto abstract = lite::CreateTensorAbstract(shape_vector, type); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; return nullptr; } - parameter->set_abstract(abstract_tensor); + parameter->set_abstract(abstract); auto gate_weight_default = std::make_shared(type, shape_vector); if (gate_weight_default == nullptr) { diff --git a/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc b/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc index 357195eef52..93b6bbd7dc1 100644 --- a/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/tflite_lstm_cell_fusion.cc @@ -502,13 +502,12 @@ CNodePtr TfliteLstmCellFusion::CreateOutputGetItem(const FuncGraphPtr &func_grap return nullptr; } CNodePtr get_item_cnode = func_graph->NewCNode(tuple_get_item_prim, {node, get_item_value}); - std::vector shape_vector; - auto abstract_tensor = std::make_shared(kFloat32, shape_vector); - if (abstract_tensor == nullptr) { - MS_LOG(ERROR) << "create abstract_tensor failed"; + auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; return nullptr; } - get_item_cnode->set_abstract(abstract_tensor); + get_item_cnode->set_abstract(abstract); get_item_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_output_getitem_" + std::to_string(item_index)); return get_item_cnode; @@ -581,13 +580,12 @@ STATUS TfliteLstmCellFusion::SetAbstractTuple(const CNodePtr &cnode, const int o MS_ASSERT(cnode != nullptr); AbstractBasePtrList abstract_list; for (int i = 0; i < output_num; ++i) { - std::vector shape_vector; - auto abstract_tensor = std::make_shared(kFloat32, shape_vector); - if (abstract_tensor == nullptr) { - MS_LOG(ERROR) << "create abstract_tensor failed"; + auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; return RET_ERROR; } - abstract_list.emplace_back(abstract_tensor); + abstract_list.emplace_back(abstract); } auto abstract_tuple = std::make_shared(abstract_list); if (abstract_tuple == nullptr) { diff --git a/mindspore/lite/tools/optimizer/graph/functionalize_while.cc b/mindspore/lite/tools/optimizer/graph/functionalize_while.cc index 9881170cdf0..2949075d0e9 100644 --- a/mindspore/lite/tools/optimizer/graph/functionalize_while.cc +++ b/mindspore/lite/tools/optimizer/graph/functionalize_while.cc @@ -23,6 +23,7 @@ #include "ops/return.h" #include "ops/tuple_get_item.h" #include "tools/converter/ops/while.h" +#include "tools/common/tensor_util.h" namespace { mindspore::ValueNodePtr GetWhileAnfPrim() { @@ -207,9 +208,13 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() { auto node_users = manager->node_users()[node]; for (auto &node_user : node_users) { // new getitem - AbstractBasePtrList abstractList; - std::vector shape_vector; - abstractList.emplace_back(std::make_shared(kFloat32, shape_vector)); + AbstractBasePtrList abstract_list; + auto abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + abstract_list.emplace_back(abstract); auto tuple_get_item_prim_ptr = std::make_shared(); if (tuple_get_item_prim_ptr == nullptr) { MS_LOG(ERROR) << "GetTupleGetItemPrim return nullptr"; @@ -225,12 +230,12 @@ STATUS FunctionalizeWhile::UpdateExitNodeUser() { std::vector inputs{tuple_get_item_prim, while_node_, getItemValue}; CNodePtr get_item_node = fg_->NewCNode(inputs); std::string output_item_name = while_node_->fullname_with_scope() + "_getitem_" + std::to_string(output_idx); - auto abstract = std::make_shared(kFloat32, shape_vector); - if (abstract == nullptr) { - MS_LOG(ERROR) << "create AbstractTensor failed"; - return RET_NULL_PTR; + auto get_item_node_abstract = lite::CreateTensorAbstract({}, kNumberTypeFloat32); + if (get_item_node_abstract == nullptr) { + MS_LOG(ERROR) << "Create get_item_node_abstract failed"; + return RET_ERROR; } - get_item_node->set_abstract(abstract); + get_item_node->set_abstract(get_item_node_abstract); get_item_node->set_fullname_with_scope(output_item_name); // set if (fg_->nodes().contains(node_user.first)) { diff --git a/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc b/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc index f6d81837894..847fb4f9643 100644 --- a/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/group_depthwise_op_convert_pass.cc @@ -22,6 +22,7 @@ #include "src/tensor.h" #include "tools/converter/quantizer/quant_cast.h" #include "src/common/log_adapter.h" +#include "tools/common/tensor_util.h" #include "securec/include/securec.h" namespace mindspore::opt { @@ -101,13 +102,16 @@ bool GroupDepthwiseOpConvertPass::Run(const FuncGraphPtr &graph) { return false; } auto type_id = static_cast(weight_value->data_type()); - auto type_ptr = TypeIdToType(type_id); auto shape = weight_value->shape(); std::vector shape_vector; (void)std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), [](const int32_t &value) { return static_cast(value); }); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - weight_node->set_abstract(abstract_tensor); + auto abstract = lite::CreateTensorAbstract(shape_vector, type_id); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + weight_node->set_abstract(abstract); } } return true; diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc index 9d2e067bbce..1b4f1a52f8c 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc @@ -21,6 +21,7 @@ #include "tools/common/node_util.h" #include "tools/common/tensor_util.h" #include "src/common/common.h" +#include "src/common/tensor_util.h" #include "src/ops/populate/populate_register.h" #include "src/ops/ops_utils.h" #include "src/runtime/infer_manager.h" @@ -28,19 +29,6 @@ namespace mindspore::opt { namespace { constexpr size_t INITIAL_SIZE = 1024; -tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) { - std::vector shape(tensor->shape()); - std::vector shape_vector; - std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector), - [](const int32_t &value) { return static_cast(value); }); - auto tensor_info = std::make_shared(tensor->data_type(), shape_vector); - if (tensor_info == nullptr) { - MS_LOG(ERROR) << "new tensor::Tensor failed"; - return nullptr; - } - return tensor_info; -} - bool IsSpecialType(const CNodePtr &cnode) { if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || CheckPrimitiveType(cnode, prim::kPrimDepend) || CheckPrimitiveType(cnode, prim::kPrimMakeTuple) || CheckPrimitiveType(cnode, prim::kPrimReturn) || @@ -75,21 +63,14 @@ STATUS GetTensorInfoFromAbstract(tensor::TensorPtr *tensor_info, const CNodePtr } } // namespace -abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor) { +abstract::AbstractBasePtr InferShapePass::ConvertLiteTensorToAbstract(lite::Tensor *tensor) { MS_ASSERT(nullptr != tensor); - std::vector shape(tensor->shape()); + auto shape = tensor->shape(); auto type_id = static_cast(tensor->data_type()); - auto type_ptr = TypeIdToType(type_id); std::vector shape_vector(shape.begin(), shape.end()); - auto new_abstract = std::make_shared(type_ptr, shape_vector); - if (new_abstract == nullptr) { - MS_LOG(ERROR) << "new AbstractTensor failed"; - return nullptr; - } - - auto tensor_info = NewTensorInfo(tensor); + auto tensor_info = lite::CreateTensorInfo(nullptr, 0, shape_vector, type_id); if (tensor_info == nullptr) { - MS_LOG(ERROR) << "new tensor::Tensor failed"; + MS_LOG(DEBUG) << "Create tensor info failed"; return nullptr; } @@ -112,8 +93,12 @@ abstract::AbstractTensorPtr InferShapePass::ConvertLiteTensorToAbstractTensor(li return nullptr; } } - new_abstract->set_value(tensor_info); - return new_abstract; + auto abstract = tensor_info->ToAbstract(); + if (abstract == nullptr) { + MS_LOG(DEBUG) << "Create tensor abstarct failed"; + return nullptr; + } + return abstract; } STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) { @@ -143,8 +128,6 @@ STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) { std::vector shape; (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape), [](const int64_t &value) { return static_cast(value); }); - - auto new_abstract = std::make_shared(type_ptr, shape_vector); auto new_tensor_info = std::make_shared(type_ptr->type_id(), shape_vector); if (parameter->has_default()) { auto old_tensor_info = std::dynamic_pointer_cast(parameter->default_param()); @@ -155,7 +138,11 @@ STATUS InferShapePass::SetParameterAbstract(const ParameterPtr ¶meter) { return RET_ERROR; } } - new_abstract->set_value(new_tensor_info); + auto new_abstract = new_tensor_info->ToAbstract(); + if (new_abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } parameter->set_abstract(new_abstract); return RET_OK; } @@ -304,7 +291,7 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector &outpu } if (output_tensors.size() == 1) { auto tensor = output_tensors.front(); - auto new_abstract = ConvertLiteTensorToAbstractTensor(tensor); + auto new_abstract = ConvertLiteTensorToAbstract(tensor); if (new_abstract == nullptr) { return RET_ERROR; } @@ -313,7 +300,7 @@ STATUS InferShapePass::SetCNodeAbstract(const std::vector &outpu AbstractBasePtrList abstract_list; for (size_t i = 0; i < output_tensors.size(); i++) { auto tensor = output_tensors.front(); - auto new_abstract = ConvertLiteTensorToAbstractTensor(tensor); + auto new_abstract = ConvertLiteTensorToAbstract(tensor); if (new_abstract == nullptr) { return RET_ERROR; } diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.h b/mindspore/lite/tools/optimizer/graph/infershape_pass.h index c753f0cecdf..4e727888071 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.h +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.h @@ -36,7 +36,7 @@ class InferShapePass : public Pass { private: void FreeTensors(std::vector *tensors); - abstract::AbstractTensorPtr ConvertLiteTensorToAbstractTensor(lite::Tensor *tensor); + abstract::AbstractBasePtr ConvertLiteTensorToAbstract(lite::Tensor *tensor); STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector *input_tensors); STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector *output_tensors); STATUS SetParameterAbstract(const ParameterPtr ¶meter); diff --git a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc index 210c686b35f..135035f02c8 100644 --- a/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/mindir_adjust_pass.cc @@ -179,23 +179,23 @@ int MindirAdjustPass::ValueNodeInt64Convert(AnfNodePtr anf_node) { if (!utils::isa(anf_node)) { return lite::RET_NO_CHANGE; } - auto valueNode = anf_node->cast(); - if (valueNode->abstract() == nullptr) { + auto value_node = anf_node->cast(); + if (value_node->abstract() == nullptr) { return lite::RET_NO_CHANGE; } - auto abstractTensor = utils::cast(valueNode->abstract()); - if (abstractTensor == nullptr) { + auto abstract_tensor = utils::cast(value_node->abstract()); + if (abstract_tensor == nullptr) { return lite::RET_NO_CHANGE; } - auto value = abstractTensor->GetValueTrack(); + auto value = abstract_tensor->GetValueTrack(); if (value != nullptr && value->isa()) { - if (abstractTensor->element() == nullptr) { + if (abstract_tensor->element() == nullptr) { MS_LOG(ERROR) << "abstractTensor->element() is nullptr."; return RET_ERROR; } - auto typePtr = abstractTensor->element()->GetTypeTrack(); - if (typePtr->type_id() == kNumberTypeInt64) { - auto shape_vector = utils::cast(abstractTensor->BuildShape())->shape(); + auto type_ptr = abstract_tensor->element()->GetTypeTrack(); + if (type_ptr->type_id() == kNumberTypeInt64) { + auto shape_vector = utils::cast(abstract_tensor->BuildShape())->shape(); auto dest_tensor_info = std::make_shared(kNumberTypeInt32, shape_vector); auto *dest_data_buf = reinterpret_cast(dest_tensor_info->data_c()); auto src_tensor_info = value->cast(); @@ -204,10 +204,10 @@ int MindirAdjustPass::ValueNodeInt64Convert(AnfNodePtr anf_node) { for (int i = 0; i < dest_tensor_info->ElementsNum(); i++) { dest_data_buf[i] = src_data_buf[i]; } - abstractTensor->set_value(dest_tensor_info); - abstractTensor->set_type(TypeIdToType(kNumberTypeInt32)); - abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt32)); - valueNode->set_value(dest_tensor_info); + abstract_tensor->set_value(dest_tensor_info); + abstract_tensor->set_type(TypeIdToType(kNumberTypeInt32)); + abstract_tensor->element()->set_type(TypeIdToType(kNumberTypeInt32)); + value_node->set_value(dest_tensor_info); } } return lite::RET_NO_CHANGE; diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc index 904dffc8180..852a51a0d04 100644 --- a/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc @@ -19,6 +19,7 @@ #include #include "ops/transpose.h" #include "tools/optimizer/common/gllo_utils.h" +#include "tools/common/tensor_util.h" using mindspore::lite::converter::FmkType_CAFFE; using mindspore::lite::converter::FmkType_MS; @@ -92,9 +93,20 @@ lite::STATUS WeightFormatTransformPass::TransposeInsertForWeightSharing(const Fu auto perm_node = BuildIntVecParameterNode(graph, perm, weight_node->fullname_with_scope() + "_perm"); auto prim = std::make_shared(); auto transpose_node = graph->NewCNode(prim, {weight_node, perm_node}); - auto type_ptr = TypeIdToType(kTypeUnknown); - std::vector shape_vector; - auto abstract = std::make_shared(type_ptr, shape_vector); + if (!weight_node->has_default()) { + MS_LOG(DEBUG) << "Weight parameter should has default parameter."; + return lite::RET_ERROR; + } + auto weight_tensor = weight_node->default_param()->cast(); + if (weight_tensor == nullptr) { + MS_LOG(DEBUG) << "Default parameter of weight parameter should be a tensor."; + return lite::RET_ERROR; + } + auto abstract = lite::CreateTensorAbstract(weight_tensor->shape_c(), weight_tensor->data_type()); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } transpose_node->set_abstract(abstract); transpose_node->set_fullname_with_scope(weight_node->fullname_with_scope() + "_post"); for (auto &adjust_node : adjust_nodes) { @@ -177,11 +189,14 @@ lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr return false; } auto type_id = static_cast(weight_value->data_type()); - auto type_ptr = TypeIdToType(type_id); auto shape = weight_value->shape(); std::vector shape_vector(shape.begin(), shape.end()); - auto abstract_tensor = std::make_shared(type_ptr, shape_vector); - weight_node->set_abstract(abstract_tensor); + auto abstract = lite::CreateTensorAbstract(shape_vector, type_id); + if (abstract == nullptr) { + MS_LOG(ERROR) << "Create tensor abstarct failed"; + return RET_ERROR; + } + weight_node->set_abstract(abstract); } return RET_OK; }