From 59c6689512e5e2538310e9657f7342317093f019 Mon Sep 17 00:00:00 2001 From: yankai Date: Sun, 21 Jun 2020 09:48:20 +0800 Subject: [PATCH] read pb model --- include/ms_tensor.h | 2 +- mindspore/ccsrc/ir/anf.h | 1 + mindspore/ccsrc/session/CMakeLists.txt | 1 + .../ccsrc/session/ascend_inference_session.cc | 90 +++++++++++ .../ccsrc/session/ascend_inference_session.h | 45 ++++++ mindspore/ccsrc/session/session.cc | 5 +- mindspore/ccsrc/session/session_basic.cc | 1 + mindspore/ccsrc/utils/base_ref_utils.cc | 39 +++-- mindspore/ccsrc/utils/base_ref_utils.h | 5 +- mindspore/ccsrc/utils/context/ms_context.h | 1 + .../ccsrc/utils/load_onnx/anf_converter.cc | 28 ---- .../ccsrc/utils/load_onnx/anf_converter.h | 1 - .../ccsrc/utils/load_onnx/anf_model_parser.cc | 142 ++++++++++++------ .../ccsrc/utils/load_onnx/anf_model_parser.h | 17 +-- 14 files changed, 270 insertions(+), 108 deletions(-) create mode 100644 mindspore/ccsrc/session/ascend_inference_session.cc create mode 100644 mindspore/ccsrc/session/ascend_inference_session.h diff --git a/include/ms_tensor.h b/include/ms_tensor.h index 2e715aa7733..1f9661df5e2 100644 --- a/include/ms_tensor.h +++ b/include/ms_tensor.h @@ -63,7 +63,7 @@ class MS_API MSTensor { // return A pointer points to data in MSTensor. virtual void *MutableData() const = 0; }; -using MultiTensor = std::vector>>; +using MultiTensor = std::vector>; } // namespace inference } // namespace mindspore #endif // MINDSPORE_INCLUDE_MS_TENSOR_H_ diff --git a/mindspore/ccsrc/ir/anf.h b/mindspore/ccsrc/ir/anf.h index c2db17aec5a..d663d9a9d13 100644 --- a/mindspore/ccsrc/ir/anf.h +++ b/mindspore/ccsrc/ir/anf.h @@ -216,6 +216,7 @@ class CNode : public AnfNode { void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; } std::string fullname_with_scope() override; + void set_fullname_with_scope(const std::string full_name) { fullname_with_scope_ = full_name; } std::string DebugString(int recursive_level = 1) const override; std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } diff --git a/mindspore/ccsrc/session/CMakeLists.txt b/mindspore/ccsrc/session/CMakeLists.txt index 2824af8a5d1..782eb511837 100644 --- a/mindspore/ccsrc/session/CMakeLists.txt +++ b/mindspore/ccsrc/session/CMakeLists.txt @@ -23,6 +23,7 @@ if (ENABLE_D) file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend_session.cc" "ascend_control_parser.cc" + "ascend_inference_session.cc" ) list(APPEND _SESSION_SRC_LIST ${_D_SRC_LIST}) endif () diff --git a/mindspore/ccsrc/session/ascend_inference_session.cc b/mindspore/ccsrc/session/ascend_inference_session.cc new file mode 100644 index 00000000000..ff538745028 --- /dev/null +++ b/mindspore/ccsrc/session/ascend_inference_session.cc @@ -0,0 +1,90 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "session/ascend_inference_session.h" +#include "operator/ops.h" +#include "ir/tensor.h" +#include "ir/anf.h" +#include "ir/param_value_py.h" +#include "device/kernel_runtime.h" +#include "session/anf_runtime_algorithm.h" +#include "common/utils.h" +#include "common/trans.h" +#include "kernel/tbe/tbe_python_funcs.h" +#include "utils/config_manager.h" +#include "utils/base_ref_extends.h" + +namespace mindspore { +namespace session { +void AscendInferenceSession::LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + std::vector inputs(inputs_const); + auto input_nodes = kernel_graph->inputs(); + + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + size_t no_weight_input = 0; + for (size_t i = 0; i < input_nodes.size(); ++i) { + tensor::TensorPtr tensor = nullptr; + if (!input_nodes[i]->isa()) { + MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter"; + continue; + } + auto pk_node = input_nodes[i]->cast(); + MS_EXCEPTION_IF_NULL(pk_node); + if (AnfAlgo::IsParameterWeight(pk_node)) { + auto param_value = std::dynamic_pointer_cast(pk_node->default_param()); + MS_EXCEPTION_IF_NULL(param_value); + auto py_param = param_value->value(); + MS_EXCEPTION_IF_NULL(py_param); + py::array py_array = py_param.cast(); + tensor = std::make_shared(py_array); + } else { + tensor = inputs[no_weight_input++]; + } + MS_EXCEPTION_IF_NULL(tensor); + if (AnfAlgo::OutputAddrExist(pk_node, 0)) { + auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); + bool need_sync = false; + if (ms_context->enable_pynative_infer()) { + if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) { + need_sync = true; + } + } else { + if (tensor->is_dirty()) { + need_sync = true; + } else if (tensor->device_address() != device_address) { + (void)tensor->data_sync(); + need_sync = true; + } + } + if (need_sync) { + if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) { + tensor->set_device_address(device_address); + } + MS_EXCEPTION_IF_NULL(device_address); + if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c(false))) { + MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; + } + } + } + tensor->set_dirty(false); + } +} +} // namespace session +} // namespace mindspore diff --git a/mindspore/ccsrc/session/ascend_inference_session.h b/mindspore/ccsrc/session/ascend_inference_session.h new file mode 100644 index 00000000000..53be881f93d --- /dev/null +++ b/mindspore/ccsrc/session/ascend_inference_session.h @@ -0,0 +1,45 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H +#define MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "session/ascend_session.h" +#include "session/kernel_graph.h" +#include "kernel/kernel.h" +#include "session/session_factory.h" +#include "session/ascend_control_parser.h" + +namespace mindspore { +namespace session { +class AscendInferenceSession : public AscendSession { + public: + AscendInferenceSession() = default; + ~AscendInferenceSession() = default; + void LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const; +}; +MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession); +} // namespace session +} // namespace mindspore +#endif // MINDSPORE_CCSRC_SESSION_ASCEND_INFERENCE_SESSION_H diff --git a/mindspore/ccsrc/session/session.cc b/mindspore/ccsrc/session/session.cc index f70ff316da0..bd883dc8419 100644 --- a/mindspore/ccsrc/session/session.cc +++ b/mindspore/ccsrc/session/session.cc @@ -124,7 +124,7 @@ MultiTensor Session::RunGraph(uint32_t graph_id, const std::vector>> multiTensor; + std::vector> multiTensor; return multiTensor; } VectorRef outputs; @@ -135,6 +135,9 @@ MultiTensor Session::RunGraph(uint32_t graph_id, const std::vectorset_execution_mode(kGraphMode); + ms_context->set_device_target(kAscendDevice); session_impl_ = session::SessionFactory::Get().Create(device); if (session_impl_ == nullptr) { MS_LOG(ERROR) << "Session create failed!, please make sure target device:" << device << " is available."; diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc index e5e58045cdd..bd6c8fbc92a 100644 --- a/mindspore/ccsrc/session/session_basic.cc +++ b/mindspore/ccsrc/session/session_basic.cc @@ -610,6 +610,7 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP auto new_cnode = CreateNewCNode(cnode, graph.get()); MS_EXCEPTION_IF_NULL(new_cnode); new_cnode->set_abstract(cnode->abstract()); + new_cnode->set_fullname_with_scope(cnode->fullname_with_scope()); new_cnode->set_scope(cnode->scope()); graph->FrontBackendlMapAdd(node, new_cnode); if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimReturn)) { diff --git a/mindspore/ccsrc/utils/base_ref_utils.cc b/mindspore/ccsrc/utils/base_ref_utils.cc index 617057b866b..87089c62667 100644 --- a/mindspore/ccsrc/utils/base_ref_utils.cc +++ b/mindspore/ccsrc/utils/base_ref_utils.cc @@ -21,20 +21,27 @@ #include "ir/tensor.h" namespace mindspore { -std::vector> TransformBaseRefToMSTensor(const BaseRef &base_ref) { +void IterateFindTensor(std::vector> *msTensors, const VectorRef &ref_list) { + for (size_t i = 0; i < ref_list.size(); ++i) { + if (utils::isa(ref_list[i])) { + auto tensor_ptr = utils::cast>(ref_list[i]); + MS_EXCEPTION_IF_NULL(tensor_ptr); + auto tensor = new inference::Tensor(tensor_ptr); + msTensors->emplace_back(std::shared_ptr(tensor)); + } else if (utils::isa(ref_list[i])) { + auto ref_iter = utils::cast(ref_list[i]); + IterateFindTensor(msTensors, ref_iter); + } else { + MS_LOG(EXCEPTION) << "The output is not a tensor"; + } + } +} + +std::vector> TransformVectorRefToMultiTensor(const VectorRef &base_ref) { std::vector> msTensors; if (utils::isa(base_ref)) { auto ref_list = utils::cast(base_ref); - for (size_t i = 0; i < ref_list.size(); ++i) { - if (utils::isa(ref_list[i])) { - auto tensor_ptr = utils::cast>(ref_list[i]); - MS_EXCEPTION_IF_NULL(tensor_ptr); - auto tensor = new inference::Tensor(tensor_ptr); - msTensors.emplace_back(std::shared_ptr(tensor)); - } else { - MS_LOG(EXCEPTION) << "The output is not a tensor!"; - } - } + IterateFindTensor(&msTensors, ref_list); } else if (utils::isa(base_ref)) { auto tensor_ptr = utils::cast>(base_ref); MS_EXCEPTION_IF_NULL(tensor_ptr); @@ -45,14 +52,4 @@ std::vector> TransformBaseRefToMSTensor(con } return msTensors; } - -std::vector>> TransformVectorRefToMultiTensor( - const VectorRef &vector_ref) { - std::vector>> multiTensor; - for (size_t i = 0; i < vector_ref.size(); ++i) { - auto tensors = TransformBaseRefToMSTensor(vector_ref[i]); - multiTensor.emplace_back(tensors); - } - return multiTensor; -} } // namespace mindspore diff --git a/mindspore/ccsrc/utils/base_ref_utils.h b/mindspore/ccsrc/utils/base_ref_utils.h index 787918c7240..2503eab7388 100644 --- a/mindspore/ccsrc/utils/base_ref_utils.h +++ b/mindspore/ccsrc/utils/base_ref_utils.h @@ -22,9 +22,6 @@ #ifndef MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H #define MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H namespace mindspore { -std::vector> TransformBaseRefToMSTensor(const BaseRef &base_ref); - -std::vector>> TransformVectorRefToMultiTensor( - const VectorRef &vector_ref); +std::vector> TransformVectorRefToMultiTensor(const VectorRef &base_ref); } // namespace mindspore #endif // MINDSPORE_CCSRC_UTILS_BASE_REF_UTILS_H diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h index a1ab728bc74..7aaa4e47503 100644 --- a/mindspore/ccsrc/utils/context/ms_context.h +++ b/mindspore/ccsrc/utils/context/ms_context.h @@ -41,6 +41,7 @@ const int kPynativeMode = 1; const char kCPUDevice[] = "CPU"; const char kGPUDevice[] = "GPU"; const char kAscendDevice[] = "Ascend"; +const char kDavinciInferenceDevice[] = "AscendInference"; const char kDavinciDevice[] = "Davinci"; const char KNpuLog[] = "_npu_log"; const std::set kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice}; diff --git a/mindspore/ccsrc/utils/load_onnx/anf_converter.cc b/mindspore/ccsrc/utils/load_onnx/anf_converter.cc index f46da657cce..ad87d6ae8fb 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_converter.cc +++ b/mindspore/ccsrc/utils/load_onnx/anf_converter.cc @@ -96,8 +96,6 @@ std::shared_ptr AnfConverter::RunAnfConverter(const std::string &file ReadOnnxFromBinary(modelFile, &model_); MSANFModelParser model_parser; FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_); - MS_EXCEPTION_IF_NULL(dstgraph_ptr); - TestFuncGraphBuild(dstgraph_ptr); return dstgraph_ptr; } @@ -111,33 +109,7 @@ std::shared_ptr AnfConverter::RunAnfConverter(const char *buf, const } MSANFModelParser model_parser; FuncGraphPtr dstgraph_ptr = model_parser.Parse(model_); - MS_EXCEPTION_IF_NULL(dstgraph_ptr); - TestFuncGraphBuild(dstgraph_ptr); return dstgraph_ptr; } - -int AnfConverter::TestFuncGraphBuild(const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(graph); - auto node_return = graph->get_return(); - std::vector node_list = TopoSort(node_return); - MS_LOG(INFO) << "node_list size is : " << node_list.size(); - for (auto &node : node_list) { - if (node->isa()) { - auto node_CN = node->cast(); - MS_LOG(INFO) << "CN node: " << node_CN->input(0)->ToString() << ", input size :" << node_CN->size(); - } else if (node->isa()) { - auto node_Para = node->cast(); - if (node_Para->has_default()) { - MS_LOG(INFO) << "Parameter node: " << node_Para->name() << "has default value!"; - } else { - MS_LOG(INFO) << "Parameter node: " << node_Para->name(); - } - } else if (node->isa()) { - auto node_Value = node->cast(); - MS_LOG(INFO) << "Value node: " << node_Value->ToString(); - } - } - return 0; -} } // namespace lite } // namespace mindspore diff --git a/mindspore/ccsrc/utils/load_onnx/anf_converter.h b/mindspore/ccsrc/utils/load_onnx/anf_converter.h index 2c820053ee2..4f5fe3971fd 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_converter.h +++ b/mindspore/ccsrc/utils/load_onnx/anf_converter.h @@ -26,7 +26,6 @@ namespace mindspore { namespace lite { class AnfConverter { public: - static int TestFuncGraphBuild(const FuncGraphPtr &graph); static std::shared_ptr RunAnfConverter(const std::string &file_path); static std::shared_ptr RunAnfConverter(const char *buf, const size_t buf_size); diff --git a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc index d624bc51c88..e44eb230017 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc +++ b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc @@ -14,16 +14,17 @@ * limitations under the License. */ +#include "utils/load_onnx/anf_model_parser.h" #include #include #include #include #include -#include "utils/load_onnx/anf_model_parser.h" #include "google/protobuf/io/zero_copy_stream_impl.h" #include "ir/tensor.h" #include "ir/param_value_py.h" #include "operator/ops.h" +#include "pipeline/static_analysis/abstract_value.h" #include "proto/onnx.pb.h" #include "utils/log_adapter.h" @@ -33,6 +34,8 @@ namespace mindspore { namespace lite { static constexpr char kConstantValueNode[] = "Constant"; static constexpr char kCNodeShapeAttr[] = "shape"; +static constexpr char kCNodeShape1Attr[] = "shape1"; +static constexpr char kCNodeShape2Attr[] = "shape2"; enum ParseForm : int { FORM_PARSE_TYPE = 0, FORM_PARSE_SCALAR = 1, @@ -56,14 +59,15 @@ static std::unordered_map kDefaultValueSwitchMap{ void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \ const onnx::TensorProto &attr_tensor) { \ MS_EXCEPTION_IF_NULL(prim); \ - std::vector attr_value_vec; \ + std::vector attr_value_vec; \ for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ - attr_value_vec.push_back(static_cast(attr_tensor.type##_data(i))); \ + auto value = static_cast(attr_tensor.type##_data(i)); \ + attr_value_vec.push_back(MakeValue(value)); \ } \ if (attr_value_vec.size() == 1) { \ - prim->AddAttr(attr_name, MakeValue(attr_value_vec[0])); \ + prim->AddAttr(attr_name, attr_value_vec[0]); \ } else { \ - prim->AddAttr(attr_name, MakeValue>(attr_value_vec)); \ + prim->AddAttr(attr_name, std::make_shared(attr_value_vec)); \ } \ } @@ -247,17 +251,12 @@ bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node const std::string &tensor_buf = attr_tensor.raw_data(); auto *tensor_data_buf = reinterpret_cast(tensor_info->data_c(true)); memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), tensor_buf.data(), tensor_buf.size()); - if (attr_tensor_type == onnx::TensorProto_DataType_FLOAT) { - auto *data_valuennode = reinterpret_cast(tensor_info->data_c()); - MS_EXCEPTION_IF_NULL(data_valuennode); - auto new_value_node = std::make_shared(MakeValue(*data_valuennode)); - anfnode_build_map_[value_node_name] = new_value_node; - } else { - auto *data_valuenode = reinterpret_cast(tensor_info->data_c()); - MS_EXCEPTION_IF_NULL(data_valuenode); - auto new_value_node = std::make_shared(MakeValue(*data_valuenode)); - anfnode_build_map_[value_node_name] = new_value_node; - } + auto new_value_node = NewValueNode(MakeValue(tensor_info)); + MS_EXCEPTION_IF_NULL(new_value_node); + auto tensor_abstract = tensor_info->ToAbstract(); + MS_EXCEPTION_IF_NULL(tensor_abstract); + new_value_node->set_abstract(tensor_abstract); + anfnode_build_map_[value_node_name] = new_value_node; return true; } @@ -315,7 +314,9 @@ bool MSANFModelParser::ObtainValueNodeInTypeForm(const std::string &value_node_n MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; return false; } - auto new_value_node = std::make_shared(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + abstract::AbstractTypePtr abs_type = std::make_shared(std::make_shared()); + new_value_node->set_abstract(abs_type); anfnode_build_map_[value_node_name] = new_value_node; return true; } @@ -361,31 +362,45 @@ AbstractBasePtr MSANFModelParser::GetAbstractForCNode(const onnx::AttributeProto tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[attr_tensor.data_type()], shape_vec); MS_EXCEPTION_IF_NULL(tensor_info); - return tensor_info->ToAbstract(); + auto abstract = tensor_info->ToAbstract(); + MS_EXCEPTION_IF_NULL(abstract); + return abstract; } -bool MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, - const onnx::GraphProto &importProto, const bool &ret_flag) { +CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::NodeProto &node_proto) { MS_EXCEPTION_IF_NULL(outputFuncGraph); if (!node_proto.has_op_type()) { MS_LOG(ERROR) << "Get CNode op_type failed!"; - return false; + return nullptr; } const std::string &node_name = node_proto.output(0); + const std::string &fullname_with_scope = node_proto.domain(); const std::string &node_type = node_proto.op_type(); PrimitivePtr prim = std::make_shared(node_type); MS_EXCEPTION_IF_NULL(prim); + prim->set_instance_name(node_type); - AbstractBasePtr abstract; + AbstractBasePtr abstract = nullptr; + AbstractBasePtr abstract_first = nullptr; + AbstractBasePtr abstract_second = nullptr; for (int i = 0; i < node_proto.attribute_size(); ++i) { const onnx::AttributeProto &attr_proto = node_proto.attribute(i); if (attr_proto.name() == kCNodeShapeAttr) { abstract = GetAbstractForCNode(attr_proto); continue; } + if (attr_proto.name() == kCNodeShape1Attr) { + abstract_first = GetAbstractForCNode(attr_proto); + continue; + } + if (attr_proto.name() == kCNodeShape2Attr) { + abstract_second = GetAbstractForCNode(attr_proto); + continue; + } if (!GetAttrValueForCNode(prim, attr_proto)) { MS_LOG(ERROR) << "Get CNode attr failed!"; - return false; + return nullptr; } } @@ -396,16 +411,64 @@ bool MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGrap const std::string &input_name = node_proto.input(i); if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; - return false; + return nullptr; } inputs.push_back(anfnode_build_map_[input_name]); } CNodePtr cnode_ptr = outputFuncGraph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(cnode_ptr); - cnode_ptr->set_abstract(abstract); - if (ret_flag) { + if (node_type == "LayerNorm") { + AbstractBasePtrList elem; + elem.push_back(abstract); + elem.push_back(abstract_first); + elem.push_back(abstract_second); + cnode_ptr->set_abstract(std::make_shared(elem)); + } else if (node_type == "ArgMaxWithValue") { + AbstractBasePtrList elem; + elem.push_back(abstract); + elem.push_back(abstract_first); + cnode_ptr->set_abstract(std::make_shared(elem)); + } else if (nullptr == abstract) { + AbstractBasePtrList elem; + for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { + elem.push_back(cnode_ptr->input(index)->abstract()); + } + cnode_ptr->set_abstract(std::make_shared(elem)); + } else { + cnode_ptr->set_abstract(abstract); + } + cnode_ptr->set_fullname_with_scope(fullname_with_scope); + anfnode_build_map_[node_name] = cnode_ptr; + return cnode_ptr; +} + +bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, + const CNodePtr &cnode_ptr) { + MS_EXCEPTION_IF_NULL(outputFuncGraph); + MS_EXCEPTION_IF_NULL(cnode_ptr); + std::vector inputs; + if (importProto.output_size() > 1) { + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + AbstractBasePtrList elem; + for (int out_size = 0; out_size < importProto.output_size(); ++out_size) { + const onnx::ValueInfoProto &output_node = importProto.output(out_size); + const std::string &out_tuple = output_node.name(); + inputs.push_back(anfnode_build_map_[out_tuple]); + elem.push_back(anfnode_build_map_[out_tuple]->abstract()); + } + auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); + maketuple_ptr->set_abstract(std::make_shared(elem)); + inputs.clear(); + inputs.push_back(NewValueNode(prim::kPrimReturn)); + inputs.push_back(maketuple_ptr); + auto return_node = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(return_node); + outputFuncGraph->set_return(return_node); + MS_LOG(INFO) << "Construct funcgraph finined, all success."; + } else { const onnx::ValueInfoProto &output_node = importProto.output(0); - const ::onnx::TypeProto &output_typeproto = output_node.type(); + const onnx::TypeProto &output_typeproto = output_node.type(); int output_type = output_typeproto.tensor_type().elem_type(); std::vector output_shape; for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { @@ -417,20 +480,19 @@ bool MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGrap inputs.push_back(NewValueNode(prim::kPrimReturn)); inputs.push_back(cnode_ptr); auto return_node = outputFuncGraph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(return_node); return_node->set_abstract(tensor_return->ToAbstract()); outputFuncGraph->set_return(return_node); MS_LOG(INFO) << "Construct funcgraph finined, all success!"; } - anfnode_build_map_[node_name] = cnode_ptr; return true; } bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { MS_EXCEPTION_IF_NULL(outputFuncGraph); - bool return_flag = false; MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); + CNodePtr cnode_ptr = nullptr; for (int i = 0; i < importProto.node_size(); ++i) { - return_flag = (i == importProto.node_size() - 1) ? true : return_flag; const onnx::NodeProto &node_proto = importProto.node(i); const std::string &node_type = node_proto.op_type(); if (node_type == kConstantValueNode) { @@ -440,11 +502,14 @@ bool MSANFModelParser::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, } continue; } - if (!BuildCNodeForFuncGraph(outputFuncGraph, node_proto, importProto, return_flag)) { + cnode_ptr = BuildCNodeForFuncGraph(outputFuncGraph, node_proto); + if (cnode_ptr == nullptr) { MS_LOG(ERROR) << "Build CNode for funcgraph fail at index: : " << i; return false; } } + + BuildReturnForFuncGraph(outputFuncGraph, importProto, cnode_ptr); return true; } @@ -472,12 +537,12 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const onnx::ModelProto &mode producer_name_ = model_proto.producer_name(); MS_LOG(INFO) << "producer_name :" << producer_name_; - if (!model_proto.has_producer_version()) { + if (!model_proto.has_model_version()) { MS_LOG(ERROR) << "Parse model producer version from pb file failed!"; return false; } - producer_version_ = model_proto.producer_version(); - MS_LOG(INFO) << "producer_version : " << producer_version_; + model_version_ = model_proto.model_version(); + MS_LOG(INFO) << "producer_version : " << model_version_; if (!model_proto.has_ir_version()) { MS_LOG(ERROR) << "Parse model version from pb file failed!"; @@ -485,14 +550,6 @@ bool MSANFModelParser::MSANFParseModelConfigureInfo(const onnx::ModelProto &mode } ir_version_ = model_proto.ir_version(); MS_LOG(INFO) << "ir_version :" << ir_version_; - - const onnx::OperatorSetIdProto &opset_proto = model_proto.opset_import(0); - if (!opset_proto.has_version()) { - MS_LOG(ERROR) << "Parse opset version from pb file failed!"; - return false; - } - opset_version_ = opset_proto.version(); - MS_LOG(INFO) << "opset_version : " << opset_version_; return true; } @@ -501,7 +558,6 @@ FuncGraphPtr MSANFModelParser::Parse(const onnx::ModelProto &model_proto) { MS_EXCEPTION_IF_NULL(dstGraph); if (!MSANFParseModelConfigureInfo(model_proto)) { MS_LOG(ERROR) << "Parse configuration info for pb file failed!"; - return nullptr; } const onnx::GraphProto &graphBuild = model_proto.graph(); if (!BuildFuncGraph(dstGraph, graphBuild)) { diff --git a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.h b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.h index 20787cbef49..11b9cd101f8 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.h +++ b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.h @@ -29,6 +29,7 @@ namespace lite { using int32 = int32_t; using int64 = int64_t; using uint64 = uint64_t; +using float16 = Eigen::half; class MSANFModelParser { public: MSANFModelParser() = default; @@ -38,17 +39,17 @@ class MSANFModelParser { bool MSANFParseModelConfigureInfo(const onnx::ModelProto &model_proto); std::string GetProducerName() { return producer_name_; } - std::string GetProducerVersion() { return producer_version_; } + int GetProducerVersion() { return model_version_; } int GetIrVersion() { return ir_version_; } - int GetOpsetVersion() { return opset_version_; } private: bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); - bool BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, - const onnx::GraphProto &importProto, const bool &ret_flag); + CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto); + bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, + const CNodePtr &cnode_ptr); bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, const onnx::TensorProto &attr_tensor); @@ -63,15 +64,13 @@ class MSANFModelParser { bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name, const onnx::TensorProto &attr_tensor); bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); + AbstractBasePtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto); std::string producer_name_; - std::string producer_version_; - int ir_version_{}; - int opset_version_{}; + int model_version_; + int ir_version_; std::unordered_map anfnode_build_map_; std::map default_para_map_; - - AbstractBasePtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto); }; } // namespace lite } // namespace mindspore