diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 74da81231fc..3fd33f8d9f1 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -55,12 +55,13 @@ constexpr int kMaxDepth = 2048; std::list GetOrderedCNodes(const FuncGraphPtr fg) { auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1); auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector { - std::vector vecs; + std::vector vecs{}; if (node == nullptr) { return vecs; } if (node->isa()) { auto cnode = node->cast(); + MS_ASSERT(cnode != nullptr); auto &inputs = cnode->inputs(); // Check if free variables used. for (const auto &input : inputs) { @@ -78,7 +79,7 @@ std::list GetOrderedCNodes(const FuncGraphPtr fg) { return vecs; }; - std::list cnodes; + std::list cnodes{}; auto nodes = TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph); for (const auto &node : nodes) { auto cnode = dyn_cast(node); @@ -100,10 +101,7 @@ int AnfExporter::SetPostTrainOutputTensorType(const std::unique_ptrdataType = kNumberTypeInt8; } else { auto primc = primitive->cast>(); - if (primc == nullptr) { - MS_LOG(ERROR) << "primitive is nullptr."; - return RET_ERROR; - } + MS_CHECK_TRUE_MSG(primc != nullptr, RET_ERROR, "cast ptr failed"); if (primc->get_dst_t() != kNumberTypeFloat32) { first_tensor_output->dataType = kNumberTypeInt8; } @@ -220,6 +218,7 @@ int AnfExporter::CreateNewTensorForParameter(const std::unique_ptrcast(); + MS_CHECK_TRUE_MSG(param_node != nullptr, RET_NULL_PTR, "cast ptr failed"); if (FetchFromDefaultParam(param_node, converter::FmkType(meta_graphT->fmkType), &data_info) != RET_OK) { MS_LOG(ERROR) << "FetchFromDefaultParam failed."; return RET_ERROR; @@ -242,7 +241,7 @@ int AnfExporter::CreateNewTensorForParameter(const std::unique_ptr &meta_graphT, const size_t &subgraph_index) { auto &subgraph = meta_graphT->subGraph.at(subgraph_index); - 
FuncGraphPtr fg; + FuncGraphPtr fg = nullptr; std::for_each(fg_subgraph_map_.begin(), fg_subgraph_map_.end(), [&subgraph_index, &fg](const std::pair &it) { if (it.second == subgraph_index) { @@ -317,6 +316,7 @@ int AnfExporter::ExportPartialNode(const std::unique_ptr &me const bool ©_primitive, const CNodePtr &partial_cnode, const std::unique_ptr &schema_cnode) { auto prim = GetValueNode>(partial_cnode->input(0)); + MS_CHECK_TRUE_MSG(prim != nullptr, RET_NULL_PTR, "GetValueNode failed"); if (prim->name() != mindspore::ops::kNamePartialFusion) { MS_LOG(INFO) << "not is partial"; return RET_OK; @@ -324,13 +324,10 @@ int AnfExporter::ExportPartialNode(const std::unique_ptr &me auto partial_fusion_primc = schema_cnode->primitive->value.AsPartialFusion(); auto vnode = partial_cnode->input(kFirstDataIndex)->cast(); - MS_ASSERT(vnode != nullptr); + MS_CHECK_TRUE_MSG(partial_fusion_primc != nullptr, RET_NULL_PTR, "partial_fusion_primc is invalid"); + MS_CHECK_TRUE_MSG(vnode != nullptr, RET_NULL_PTR, "vnode is invalid"); auto fg = vnode->value()->cast(); - if (fg == nullptr) { - MS_LOG(ERROR) << "func graph is nullptr."; - return RET_NULL_PTR; - } - + MS_CHECK_TRUE_MSG(fg != nullptr, RET_NULL_PTR, "func graph is nullptr."); if (fg_subgraph_map_.find(fg) != fg_subgraph_map_.end()) { partial_fusion_primc->sub_graph_index = fg_subgraph_map_.at(fg); return RET_OK; @@ -376,10 +373,6 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptr>(cnode->input(kPrimIndex)); std::unique_ptr primT; - if (prim == nullptr) { - MS_LOG(ERROR) << "prim is nullptr."; - return RET_ERROR; - } ret = RemoveIfDepend(cnode); if (ret != RET_OK) { @@ -513,10 +506,13 @@ FuncGraphPtr GetFinalGraph(const FuncGraphPtr &func_graph, int i) { auto cnode = call_cnode->input(kFirstDataIndex)->cast(); if (opt::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { auto false_cnode = cnode->input(kSwitchFalseIndex)->cast(); + MS_CHECK_TRUE_MSG(false_cnode != nullptr, nullptr, "cast failed"); auto 
false_fg = GetValueNode(false_cnode->input(kFirstDataIndex)); + MS_CHECK_TRUE_MSG(false_fg != nullptr, nullptr, "GetValueNode failed"); return GetFinalGraph(false_fg, i); } else { auto fg = GetValueNode(cnode->input(kFirstDataIndex)); + MS_CHECK_TRUE_MSG(fg != nullptr, nullptr, "GetValueNode failed"); return GetFinalGraph(fg, i); } @@ -542,10 +538,7 @@ int AnfExporter::SetMetaGraphOutput(const FuncGraphPtr &func_graph, const std::unique_ptr &meta_graphT) { int i = 0; auto final_fg = GetFinalGraph(func_graph, i); - if (final_fg == nullptr) { - MS_LOG(ERROR) << "GetFinalGraph failed."; - return RET_ERROR; - } + MS_CHECK_TRUE_MSG(final_fg != nullptr, RET_ERROR, "GetFinalGraph failed."); auto final_meta_graph_index = fg_subgraph_map_.at(final_fg); auto &final_meta_graph = meta_graphT->subGraph.at(final_meta_graph_index); meta_graphT->outputIndex.assign(final_meta_graph->outputIndices.begin(), final_meta_graph->outputIndices.end()); @@ -605,10 +598,7 @@ int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema } if (utils::isa(input_anode->abstract())) { auto tuple = std::reinterpret_pointer_cast(input_anode->abstract()); - if (tuple == nullptr) { - MS_LOG(ERROR) << "tuple is nullptr"; - return RET_ERROR; - } + MS_CHECK_TRUE_MSG(tuple != nullptr, RET_ERROR, "tuple is nullptr"); auto elements = tuple->elements(); for (size_t i = 0; i < elements.size(); i++) { auto key = std::make_pair(input_anode, i); @@ -627,6 +617,7 @@ int AnfExporter::ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema int AnfExporter::ConvertInputCNode(const std::shared_ptr &input_anode, schema::CNodeT *output_cnode) { auto input_cnode = utils::cast(input_anode); + MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_ERROR, "cast ptr failed"); auto input_value_node = input_cnode->input(kPrimIndex)->cast(); if (input_value_node == nullptr) { if (!IsCall(input_cnode)) { @@ -635,6 +626,7 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr &input_anode, } else { auto 
call_anf_prim_vnode = GetCallAnfPrim(); auto cnode_input = input_cnode->inputs(); + MS_CHECK_TRUE_MSG(call_anf_prim_vnode != nullptr, RET_ERROR, "GetCallAnfPrim failed"); cnode_input.insert(cnode_input.begin(), call_anf_prim_vnode); input_cnode->set_inputs(cnode_input); } @@ -658,10 +650,7 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr &input_anode, return RET_ERROR; } auto value_node = utils::cast(index_vnode); - if (value_node == nullptr) { - MS_LOG(ERROR) << "cast to ValueNode failed"; - return RET_ERROR; - } + MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "cast to ValueNode failed"); auto idx = value_node->value()->type()->number_type() == kNumberTypeInt64 ? GetValue(value_node->value()) : GetValue(value_node->value()); auto key = std::make_pair(get_item_input_cnode, idx); @@ -794,10 +783,7 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr(cnode->abstract())) { auto tuple = std::reinterpret_pointer_cast(cnode->abstract()); - if (tuple == nullptr) { - MS_LOG(ERROR) << "tuple is nullptr"; - return RET_ERROR; - } + MS_CHECK_TRUE_MSG(tuple != nullptr, RET_ERROR, "tuple is nullptr"); auto elements = tuple->elements(); for (size_t i = 0; i < lite::GetCNodeOutputsSize(cnode, train_flag_); i++) { auto ms_tensor = new (std::nothrow) schema::TensorT(); @@ -841,6 +827,7 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr(elements[i]); auto type_ptr = abstract_tensor->element()->GetTypeTrack(); + MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr"); ms_tensor->dataType = type_ptr->type_id(); meta_graphT->allTensors.emplace_back(ms_tensor); if (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || @@ -881,8 +868,10 @@ int AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr inputs{call_anf_prim_vnode, node}; auto cnode = fg->NewCNodeInOrder(inputs); + MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "NewCNode failed"); cnode->set_func_graph(fg); return 
cnode; } @@ -890,19 +879,23 @@ CNodePtr AnfExporter::CreateCallCnode(const FuncGraphPtr &fg, const AnfNodePtr & CNodePtr AnfExporter::CreatePartialCnode(const FuncGraphPtr &fg, const AnfNodePtr &node) { if (utils::isa(node)) { auto cnode = utils::cast(node); + MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "cast ptr failed"); auto primitive_c = GetValueNode>(cnode->input(kPrimIndex)); if (primitive_c != nullptr) { return cnode; } auto partial_anf_prim_vnode = GetPartialFusionPrim(); auto cnode_input = cnode->inputs(); + MS_CHECK_TRUE_MSG(partial_anf_prim_vnode != nullptr, nullptr, "GetPartialFusionPrim failed"); cnode_input.insert(cnode_input.begin(), partial_anf_prim_vnode); cnode->set_inputs(cnode_input); return cnode; } else if (utils::isa(node)) { auto partial_anf_prim_vnode = GetPartialFusionPrim(); + MS_CHECK_TRUE_MSG(partial_anf_prim_vnode != nullptr, nullptr, "GetPartialFusionPrim failed"); std::vector inputs{partial_anf_prim_vnode, node}; auto cnode = fg->NewCNode(inputs); + MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "New cnode failed"); return cnode; } else { MS_LOG(ERROR) << "failed to create partial cnode."; diff --git a/mindspore/lite/tools/anf_exporter/fetch_content.cc b/mindspore/lite/tools/anf_exporter/fetch_content.cc index 8ea4112b5b4..e56f99e2797 100644 --- a/mindspore/lite/tools/anf_exporter/fetch_content.cc +++ b/mindspore/lite/tools/anf_exporter/fetch_content.cc @@ -91,6 +91,7 @@ STATUS GetDataTypeAndShape(const ParameterPtr ¶m_node, TypeId *data_type, Sh } auto abstract_tensor = utils::cast(abstract_base); + MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed"); auto typePtr = abstract_tensor->element()->GetTypeTrack(); MS_CHECK_TRUE_MSG(typePtr != nullptr, RET_ERROR, "typePtr is nullptr"); *data_type = typePtr->type_id(); if (!utils::isa(abstract_tensor->BuildShape())) { @@ -105,12 +106,14 @@ int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &pri bool train_flag, DataInfo *data_info) {
MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr); auto valueAbstract = value_node->abstract(); + MS_CHECK_TRUE_MSG(valueAbstract != nullptr, RET_ERROR, "valueAbstract is nullptr"); auto abstract_tensor = utils::cast(valueAbstract); if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) { MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr"; return RET_ERROR; } auto typePtr = abstract_tensor->element()->GetTypeTrack(); + MS_CHECK_TRUE_MSG(typePtr != nullptr, RET_ERROR, "typePtr is nullptr"); data_info->data_type_ = typePtr->type_id(); auto shape_vector = utils::cast(abstract_tensor->BuildShape())->shape(); std::vector dims(shape_vector.begin(), shape_vector.end()); @@ -121,6 +124,7 @@ int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &pri auto value = value_node->value(); MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr"); auto data = value->cast(); + MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is invalid"); data_info->data_.resize(data->Size()); if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) { MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_; @@ -159,6 +163,7 @@ int FetchFromBoolImmValue(const ValueNodePtr &value_node, const PrimitivePtr &pr auto value = value_node->value(); MS_CHECK_TRUE_MSG(value != nullptr, RET_ERROR, "value is nullptr"); auto data = value->cast(); + MS_CHECK_TRUE_MSG(data != nullptr, RET_ERROR, "data is nullptr"); auto data_value = data->value(); if (memcpy_s(data_info->data_.data(), sizeof(bool), &data_value, sizeof(bool)) != EOK) { MS_LOG(ERROR) << "memcpy_s failed"; @@ -173,6 +178,7 @@ int FetchFromNumberValue(const ValueNodePtr &value_node, const PrimitivePtr &pri data_info->shape_ = {1}; data_info->data_.resize(sizeof(int)); auto data = value_node->value()->cast(); + MS_CHECK_TRUE_MSG(data != nullptr, RET_NULL_PTR, "cast NumberPtr failed"); int number_type = 
data->number_type(); if (TypeToTypeMap.find(number_type) != TypeToTypeMap.end()) { number_type = TypeToTypeMap.at(number_type); @@ -256,16 +262,13 @@ int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::F DataInfo *data_info) { MS_ASSERT(cnode != nullptr && data_info != nullptr); auto param_node = cnode->input(index)->cast(); - if (param_node == nullptr) { - MS_LOG(ERROR) << "input node is not parameter node."; - return RET_ERROR; - } - + MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "input node is not parameter node."); if (FetchFromDefaultParam(param_node, fmk_type, data_info) != RET_OK) { MS_LOG(ERROR) << "fetch information from default param failed."; return RET_ERROR; } auto prim = GetValueNode(cnode->input(0)); + MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "GetValueNode failed"); if (prim->GetAttr(ops::kFormat) == nullptr && !param_node->has_default()) { data_info->format_ = mindspore::NHWC; } @@ -289,10 +292,8 @@ int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkTy DataInfo *data_info) { MS_ASSERT(cnode != nullptr && data_info != nullptr); auto value_node = cnode->input(index)->cast(); - if (value_node == nullptr) { - MS_LOG(ERROR) << "input node is not value node."; - return RET_ERROR; - } + MS_CHECK_TRUE_MSG(value_node != nullptr, RET_ERROR, "input node is not value node."); + auto value = value_node->value(); int ret = RET_OK; auto prim = GetValueNode(cnode->input(0)); @@ -328,6 +329,7 @@ int SetFormatForCnode(const CNodePtr &cnode, size_t index, converter::FmkType fm DataInfo *data_info) { data_info->format_ = mindspore::NHWC; auto input_node_prim = GetValueNode((cnode->input(index)->cast()->input(0))); + MS_CHECK_TRUE_MSG(input_node_prim != nullptr, RET_ERROR, "GetValueNode failed"); if (input_node_prim->GetAttr(ops::kFormat) != nullptr) { auto value = input_node_prim->GetAttr(ops::kFormat); if (value->isa()) { @@ -367,6 +369,7 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, 
converter::FmkType f return RET_ERROR; } auto abstract_tensor = utils::cast(abstract); + MS_CHECK_TRUE_MSG(abstract_tensor != nullptr, RET_ERROR, "cast ptr failed"); auto type_ptr = abstract_tensor->element()->GetTypeTrack(); MS_CHECK_TRUE_MSG(type_ptr != nullptr, RET_ERROR, "type_ptr is nullptr"); if (!utils::isa(abstract_tensor->BuildShape())) { @@ -390,6 +393,7 @@ int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType f return RET_ERROR; } auto tensor_value = tensor_info->cast(); + MS_CHECK_TRUE_MSG(tensor_value != nullptr, RET_ERROR, "cast ptr failed"); if (tensor_value->Size() >= kTensorListMinSize) { data_info->data_.resize(tensor_value->Size()); if (memcpy_s(data_info->data_.data(), tensor_value->Size(), tensor_value->data_c(), tensor_value->Size()) != @@ -410,21 +414,21 @@ int RemoveIfDepend(const CNodePtr &cnode) { inputs.emplace_back(cnode->input(0)); for (size_t i = 1; i < cnode->inputs().size(); ++i) { AnfNodePtr inputNode = cnode->input(i); + MS_CHECK_TRUE_MSG(inputNode != nullptr, RET_NULL_PTR, "inputNode is nullptr"); if (!inputNode->isa()) { inputs.emplace_back(cnode->input(i)); continue; } auto depend_node = utils::cast(inputNode); + MS_CHECK_TRUE_MSG(depend_node != nullptr, RET_NULL_PTR, "depend_node is nullptr"); auto value_node = depend_node->input(0)->cast(); - if (value_node == nullptr) { - MS_LOG(ERROR) << "value node is invalid."; - return RET_ERROR; - } + MS_CHECK_TRUE_MSG(value_node != nullptr, RET_NULL_PTR, "value node is invalid."); if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) { has_depend = true; bool mask_out = (depend_node->inputs().size() == opt::kInputSizeThree); for (size_t j = 1; j < depend_node->inputs().size(); ++j) { AnfNodePtr depend_input_node = depend_node->input(j); + MS_CHECK_TRUE_MSG(depend_input_node != nullptr, RET_NULL_PTR, "depend_input_node is nullptr"); if (depend_input_node->isa()) { inputs.emplace_back(depend_input_node); if (mask_out) { @@ 
-450,16 +454,15 @@ int RemoveIfMakeTuple(const CNodePtr &cnode) { inputs.emplace_back(cnode->input(0)); for (size_t i = 1; i < cnode->inputs().size(); ++i) { AnfNodePtr input_node = cnode->input(i); + MS_CHECK_TRUE_MSG(input_node != nullptr, RET_NULL_PTR, "input_node is nullptr"); if (!input_node->isa()) { inputs.emplace_back(cnode->input(i)); continue; } auto make_tuple_node = utils::cast(input_node); + MS_CHECK_TRUE_MSG(make_tuple_node != nullptr, RET_NULL_PTR, "make_tuple_node is nullptr"); auto value_node = make_tuple_node->input(0)->cast(); - if (value_node == nullptr) { - MS_LOG(ERROR) << "value node is invalid."; - return RET_ERROR; - } + MS_CHECK_TRUE_MSG(value_node != nullptr, RET_NULL_PTR, "value node is invalid."); if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) || opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) { has_make_tuple = true; diff --git a/mindspore/lite/tools/common/flag_parser.cc b/mindspore/lite/tools/common/flag_parser.cc index 58fc4d139a3..efde8cf379a 100644 --- a/mindspore/lite/tools/common/flag_parser.cc +++ b/mindspore/lite/tools/common/flag_parser.cc @@ -30,7 +30,7 @@ Option FlagParser::ParseFlags(int argc, const char *const *argv, bo } binName = GetFileName(argv[0]); - std::multimap> keyValues; + std::multimap> keyValues{}; for (int i = 1; i < argc; i++) { std::string tmp = argv[i]; Trim(&tmp); diff --git a/mindspore/lite/tools/common/func_graph_subgraph.cc b/mindspore/lite/tools/common/func_graph_subgraph.cc index afde82e33ef..a8fe2e07849 100644 --- a/mindspore/lite/tools/common/func_graph_subgraph.cc +++ b/mindspore/lite/tools/common/func_graph_subgraph.cc @@ -20,7 +20,6 @@ #include #include #include -#include #include "src/common/log_adapter.h" #include "tools/common/node_util.h" #include "tools/common/graph_util.h" @@ -29,18 +28,43 @@ #include "nnacl/op_base.h" namespace mindspore::lite { -SubGraph::SubGraph(FuncGraphPtr belong_anf, std::string graph_name, 
const std::set &head_nodes) - : belong_anf_(std::move(belong_anf)), name_(std::move(graph_name)) { - InitSubGraphNode(head_nodes); - InitSubGraphInNode(); - InitSubGraphOutNode(); +int SubGraph::Init(const std::set &head_nodes) { + auto ret = InitSubGraphNode(head_nodes); + if (ret != RET_OK) { + MS_LOG(ERROR) << "InitSubGraphNode failed"; + return RET_ERROR; + } + ret = InitSubGraphInNode(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "InitSubGraphInNode failed"; + return RET_ERROR; + } + ret = InitSubGraphOutNode(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "InitSubGraphOutNode failed"; + return RET_ERROR; + } + return RET_OK; } -void SubGraph::Reset(const std::set &nodes, const std::set &head_nodes) { +int SubGraph::Reset(const std::set &nodes, const std::set &head_nodes) { this->nodes_ = nodes; - InitSubGraphNode(head_nodes); - InitSubGraphInNode(); - InitSubGraphOutNode(); + auto ret = InitSubGraphNode(head_nodes); + if (ret != RET_OK) { + MS_LOG(ERROR) << "InitSubGraphNode failed"; + return RET_ERROR; + } + ret = InitSubGraphInNode(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "InitSubGraphInNode failed"; + return RET_ERROR; + } + ret = InitSubGraphOutNode(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "InitSubGraphOutNode failed"; + return RET_ERROR; + } + return RET_OK; } std::set SubGraph::GetNodes() const { return this->nodes_; } @@ -125,11 +149,11 @@ std::set SubGraph::GetOutputCNodes() const { return outputs; } -void SubGraph::InitSubGraphNode(const std::set &head_nodes) { +int SubGraph::InitSubGraphNode(const std::set &head_nodes) { MS_ASSERT(belong_anf_ != nullptr); MS_ASSERT(belong_anf_->manager() != nullptr); auto node_users = belong_anf_->manager()->node_users(); - std::queue q; + std::queue q{}; for (const auto &head_node : head_nodes) { if (head_node == nullptr) { continue; @@ -138,7 +162,7 @@ void SubGraph::InitSubGraphNode(const std::set &head_nodes) { } while (!q.empty()) { auto cur_node = q.front(); - MS_CHECK_PTR_IF_NULL(cur_node); + 
MS_CHECK_TRUE_MSG(cur_node != nullptr, RET_NULL_PTR, "cur_node is nullptr"); q.pop(); this->nodes_.insert(cur_node); // check output-cnode of cur-node only depend on cur-node @@ -153,11 +177,12 @@ void SubGraph::InitSubGraphNode(const std::set &head_nodes) { continue; } auto post_cnode = utils::cast(post_node); + MS_CHECK_TRUE_MSG(post_cnode != nullptr, RET_NULL_PTR, "cast failed"); // return-node should not be include into subgraph absolutely // ut if (opt::CheckPrimitiveType(post_cnode, prim::kPrimReturn)) { continue; } - MS_CHECK_PTR_IF_NULL(post_cnode); + MS_CHECK_TRUE_MSG(post_cnode != nullptr, RET_NULL_PTR, "post_cnode is nullptr"); bool non_depend = true; // check all inputs of output-cnode for (const auto &input : post_cnode->inputs()) { @@ -167,6 +192,7 @@ void SubGraph::InitSubGraphNode(const std::set &head_nodes) { // input cnode is not contained in subgraph if (utils::isa(input)) { auto input_cnode = utils::cast(input); + MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_NULL_PTR, "cast ptr failed"); if (this->nodes_.count(input_cnode) == 0) { non_depend = false; break; @@ -175,6 +201,7 @@ void SubGraph::InitSubGraphNode(const std::set &head_nodes) { // input parameter is a graph input if (utils::isa(input)) { auto input_parameter = utils::cast(input); + MS_CHECK_TRUE_MSG(input_parameter != nullptr, RET_NULL_PTR, "cast failed"); if (!input_parameter->has_default()) { non_depend = false; break; @@ -186,9 +213,10 @@ void SubGraph::InitSubGraphNode(const std::set &head_nodes) { } } } + return RET_OK; } -void SubGraph::InitSubGraphInNode() { +int SubGraph::InitSubGraphInNode() { MS_ASSERT(belong_anf_ != nullptr); MS_ASSERT(belong_anf_->manager() != nullptr); auto node_users = belong_anf_->manager()->node_users(); @@ -203,6 +231,7 @@ void SubGraph::InitSubGraphInNode() { } if (utils::isa(input)) { auto input_cnode = utils::cast(input); + MS_CHECK_TRUE_MSG(input_cnode != nullptr, false, "cast failed"); if (this->nodes_.count(input_cnode) == 0) { return true; } @@ 
-210,6 +239,7 @@ void SubGraph::InitSubGraphInNode() { // graph input or shared weight input // ut if (utils::isa(input)) { auto input_parameter = utils::cast(input); + MS_CHECK_TRUE_MSG(input_parameter != nullptr, false, "cast failed"); if (!input_parameter->has_default()) { return true; } @@ -223,9 +253,10 @@ void SubGraph::InitSubGraphInNode() { in_nodes_.insert(node); } } + return RET_OK; } -void SubGraph::InitSubGraphOutNode() { +int SubGraph::InitSubGraphOutNode() { MS_ASSERT(belong_anf_ != nullptr); MS_ASSERT(belong_anf_->manager() != nullptr); auto node_users = belong_anf_->manager()->node_users(); @@ -250,6 +281,7 @@ void SubGraph::InitSubGraphOutNode() { return true; } auto output_cnode = utils::cast(output_node); + MS_CHECK_TRUE_MSG(output_cnode != nullptr, false, "cast failed"); if (this->nodes_.count(output_cnode) == 0) { return true; } @@ -258,6 +290,7 @@ void SubGraph::InitSubGraphOutNode() { continue; out_nodes_.insert(node); } + return RET_OK; } bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) { @@ -272,7 +305,10 @@ bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) { auto new_nodes2 = subgraph->GetNodes(); new_nodes.insert(new_nodes2.begin(), new_nodes2.end()); new_nodes.insert(common_outputs.begin(), common_outputs.end()); - this->Reset(new_nodes, common_outputs); + if (this->Reset(new_nodes, common_outputs) != RET_OK) { + MS_LOG(ERROR) << "Reset failed"; + return false; + } return true; } @@ -280,7 +316,10 @@ bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) { auto new_nodes = this->GetNodes(); auto new_nodes2 = subgraph->GetNodes(); new_nodes.insert(new_nodes2.begin(), new_nodes2.end()); - this->Reset(new_nodes); + if (this->Reset(new_nodes) != RET_OK) { + MS_LOG(ERROR) << "Reset failed"; + return false; + } return true; } return false; @@ -290,8 +329,8 @@ bool SubGraph::MergeSubGraph(const SubGraphPtr &subgraph) { SubGraphPtr SubGraph::FindBeforeSubGraphInBelongAnf() const { MS_ASSERT(belong_anf_ == nullptr); // find 
before subgraph's nodes - std::queue q; - std::set before_nodes; + std::queue q{}; + std::set before_nodes{}; for (const auto &node : this->GetInCNodes()) { for (const auto &in_cnode : lite::GetInputCNode(node)) { if (in_cnode == nullptr) { @@ -311,7 +350,11 @@ SubGraphPtr SubGraph::FindBeforeSubGraphInBelongAnf() const { } // construct before subgraph auto before_subgraph = std::make_shared(belong_anf_, this->name_ + "/before_subgraph"); - before_subgraph->Reset(before_nodes); + MS_CHECK_TRUE_MSG(before_subgraph != nullptr, nullptr, "before_subgraph is nullptr"); + if (before_subgraph->Reset(before_nodes) != RET_OK) { + MS_LOG(ERROR) << "Reset failed"; + return nullptr; + } return before_subgraph; } @@ -325,8 +368,8 @@ SubGraphPtr SubGraph::FindAfterSubGraphInBelongAnf() const { return nullptr; } // find after subgraph's nodes - std::queue q; - std::set after_nodes; + std::queue q{}; + std::set after_nodes{}; auto output_node = belong_anf_->output(); if (!utils::isa(output_node)) { MS_LOG(ERROR) << "Output node of anf should be a cnode: " << output_node->fullname_with_scope(); @@ -348,7 +391,12 @@ SubGraphPtr SubGraph::FindAfterSubGraphInBelongAnf() const { } // construct before subgraph auto after_subgraph = std::make_shared(belong_anf_, this->name_ + "/after_subgraph"); - after_subgraph->Reset(after_nodes); + MS_CHECK_TRUE_MSG(after_subgraph != nullptr, nullptr, "after_subgraph is nullptr"); + if (after_subgraph->Reset(after_nodes) != RET_OK) { + MS_LOG(ERROR) << "Reset failed"; + return nullptr; + } + return after_subgraph; } @@ -366,6 +414,7 @@ int SubGraph::CreatePartialInBelongAnf() { } // create func_graph of partial FuncGraphPtr func_graph = std::make_shared(); + MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_NULL_PTR, "func_graph is nullptr"); auto manager = belong_anf_->manager(); manager->AddFuncGraph(func_graph); func_graph->set_attr("graph_name", MakeValue(graph_name)); @@ -373,12 +422,20 @@ int SubGraph::CreatePartialInBelongAnf() { // create cnode and 
parameter for func_graph of partial std::vector partial_inputs; std::map partial_inputs_and_subgraph_input_map; - CreateParameterForPartialSubGraph(func_graph, &partial_inputs, &partial_inputs_and_subgraph_input_map); - CreateCNodeForPartialSubGraph(func_graph, partial_inputs_and_subgraph_input_map); + auto ret = CreateParameterForPartialSubGraph(func_graph, &partial_inputs, &partial_inputs_and_subgraph_input_map); + if (ret != RET_OK) { + MS_LOG(DEBUG) << "CreateParameterForPartialSubGraph failed"; + return ret; + } + ret = CreateCNodeForPartialSubGraph(func_graph, partial_inputs_and_subgraph_input_map); + if (ret != RET_OK) { + MS_LOG(DEBUG) << "CreateCNodeForPartialSubGraph failed"; + return ret; + } // add return for func_graph of partial auto sub_graph_outputs = this->GetOutCNodes(); MS_ASSERT(!sub_graph_outputs.empty()); - auto ret = SetFuncGraphOutput(func_graph, sub_graph_outputs); + ret = SetFuncGraphOutput(func_graph, sub_graph_outputs); if (ret != RET_OK) { MS_LOG(DEBUG) << "Set subgraph output failed"; return ret; @@ -386,8 +443,11 @@ int SubGraph::CreatePartialInBelongAnf() { // create partial cnode auto partial_prim = std::make_shared(); auto graph_value_node = NewValueNode(func_graph); + MS_CHECK_TRUE_MSG(partial_prim != nullptr, RET_NULL_PTR, "partial_prim is nullptr"); + MS_CHECK_TRUE_MSG(graph_value_node != nullptr, RET_NULL_PTR, "graph_value_node is nullptr"); partial_inputs.insert(partial_inputs.begin(), graph_value_node); auto partial_cnode = belong_anf_->NewCNode(partial_prim, partial_inputs); + MS_CHECK_TRUE_MSG(partial_cnode != nullptr, RET_NULL_PTR, "partial_cnode is nullptr"); partial_cnode->set_fullname_with_scope(graph_name + "/partial"); for (size_t i = 0; i < partial_inputs.size(); ++i) { const auto &input = partial_inputs.at(i); @@ -396,6 +456,7 @@ int SubGraph::CreatePartialInBelongAnf() { // create call cnode std::vector call_node_inputs{partial_cnode}; auto call_cnode = belong_anf_->NewCNode(call_node_inputs); + 
MS_CHECK_TRUE_MSG(call_cnode != nullptr, RET_NULL_PTR, "call_cnode is nullptr"); call_cnode->set_fullname_with_scope(graph_name + "/call"); // replace belong-graph's output auto return_node = belong_anf_->get_return(); @@ -412,7 +473,7 @@ int SubGraph::SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set *partial_inputs, std::map *partial_inputs_and_subgraph_input_map) { MS_ASSERT(sub_graph != nullptr); @@ -436,6 +497,7 @@ void SubGraph::CreateParameterForPartialSubGraph( // create subgraph input parameter from cnode and record partial inputs if (utils::isa(input)) { auto input_cnode = utils::cast(input); + MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_NULL_PTR, "cast ptr failed"); if (this->GetNodes().count(input_cnode) > 0) { continue; } @@ -450,6 +512,7 @@ void SubGraph::CreateParameterForPartialSubGraph( auto node_users = this->belong_anf_->manager()->node_users(); if (utils::isa(input)) { auto parameter = utils::cast(input); + MS_CHECK_TRUE_MSG(parameter != nullptr, RET_NULL_PTR, "cast ptr failed"); // graph input: create a parameter if (!parameter->has_default()) { auto new_parameter = sub_graph->add_parameter(); @@ -475,9 +538,10 @@ void SubGraph::CreateParameterForPartialSubGraph( } } } + return RET_OK; } -void SubGraph::CreateCNodeForPartialSubGraph( +int SubGraph::CreateCNodeForPartialSubGraph( const FuncGraphPtr &sub_graph, const std::map &partial_inputs_and_subgraph_input_map) { MS_ASSERT(sub_graph != nullptr); // move cnode from belong_graph to subgraph @@ -500,6 +564,7 @@ void SubGraph::CreateCNodeForPartialSubGraph( } this->belong_anf_->DropNode(node); } + return RET_OK; } int SubGraph::ApplySubGraph() { @@ -546,7 +611,10 @@ int SubGraph::ApplySubGraph() { MS_ASSERT(after_partial_cnode != nullptr); subgraph_nodes.insert(after_partial_cnode); subgraph_nodes.insert(call_cnode); - this->Reset(subgraph_nodes); + if (this->Reset(subgraph_nodes) != RET_OK) { + MS_LOG(ERROR) << "Reset failed"; + return RET_ERROR; + } // create subgraph partial // add
partial to main subgraph ret = this->CreatePartialInBelongAnf(); if (ret != RET_OK) { diff --git a/mindspore/lite/tools/common/func_graph_subgraph.h b/mindspore/lite/tools/common/func_graph_subgraph.h index 5a590d26da7..56b39f1d082 100644 --- a/mindspore/lite/tools/common/func_graph_subgraph.h +++ b/mindspore/lite/tools/common/func_graph_subgraph.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "src/common/log_adapter.h" #include "include/errorcode.h" #include "ir/anf.h" @@ -32,9 +33,11 @@ class SubGraph; using SubGraphPtr = std::shared_ptr; class SubGraph { public: - explicit SubGraph(FuncGraphPtr belong_anf, std::string graph_name = "", const std::set &head_nodes = {}); + explicit SubGraph(FuncGraphPtr belong_anf, std::string graph_name = "") + : belong_anf_(std::move(belong_anf)), name_(std::move(graph_name)) {} - void Reset(const std::set &nodes, const std::set &head_nodes = {}); + int Init(const std::set &head_nodes = {}); + int Reset(const std::set &nodes, const std::set &head_nodes = {}); bool MergeSubGraph(const SubGraphPtr &subgraph); @@ -48,19 +51,19 @@ class SubGraph { std::set GetInputCNodes() const; std::set GetOutputCNodes() const; // init subgraph methods - void InitSubGraphNode(const std::set &head_nodes); - void InitSubGraphInNode(); - void InitSubGraphOutNode(); + int InitSubGraphNode(const std::set &head_nodes); + int InitSubGraphInNode(); + int InitSubGraphOutNode(); // merge subgraph methods std::set FindCommonOutputs(const SubGraphPtr &subgraph) const; bool IfDependOnSameNode(const SubGraphPtr &subgraph) const; // apply subgraph methods SubGraphPtr FindBeforeSubGraphInBelongAnf() const; SubGraphPtr FindAfterSubGraphInBelongAnf() const; - void CreateParameterForPartialSubGraph(const FuncGraphPtr &sub_graph, std::vector *partial_inputs, - std::map *partial_inputs_and_subgraph_input_map); - void CreateCNodeForPartialSubGraph(const FuncGraphPtr &sub_graph, - const std::map &partial_inputs_and_subgraph_input_map); + int 
CreateParameterForPartialSubGraph(const FuncGraphPtr &sub_graph, std::vector *partial_inputs, + std::map *partial_inputs_and_subgraph_input_map); + int CreateCNodeForPartialSubGraph(const FuncGraphPtr &sub_graph, + const std::map &partial_inputs_and_subgraph_input_map); int CreatePartialInBelongAnf(); static int SetFuncGraphOutput(const FuncGraphPtr &graph, const std::set &outputs); diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index 3dc8aff2227..c42af780b7e 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -27,6 +27,7 @@ #include "mindspore/core/ops/switch.h" #include "mindspore/core/ops/call.h" #include "mindspore/core/ops/fusion/partial_fusion.h" +#include "nnacl/op_base.h" namespace mindspore { namespace lite { @@ -332,6 +333,7 @@ STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) { size_t GetCNodeOutputsSize(const std::shared_ptr &anf_node, bool train_flag) { MS_ASSERT(anf_node != nullptr); auto cnode = anf_node->cast(); + MS_ASSERT(cnode != nullptr); if (train_flag && (opt::CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || opt::CheckPrimitiveType(cnode, prim::kPrimAdam))) { return 1; @@ -350,6 +352,7 @@ bool IsPartialFusion(const AnfNodePtr &node) { } if (node->isa()) { auto cnode = node->cast(); + MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cast ptr failed"); auto vnode_value = cnode->input(0)->cast()->value(); return GetValue(vnode_value)->name() == "PartialFusion"; } @@ -364,6 +367,7 @@ bool IsCall(const AnfNodePtr &node) { return false; } auto cnode = node->cast(); + MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cast ptr failed"); if (cnode->inputs().empty()) { return false; } @@ -400,19 +404,25 @@ bool IsMakeTuple(const AnfNodePtr &node) { ValueNodePtr GetPartialFusionPrim() { auto partial_prim = std::make_shared(); + MS_CHECK_TRUE_MSG(partial_prim != nullptr, nullptr, "partial_prim is nullptr"); ValueNodePtr 
partial_anf_prim = NewValueNode(partial_prim); + MS_CHECK_TRUE_MSG(partial_anf_prim != nullptr, nullptr, "partial_anf_prim is nullptr"); return partial_anf_prim; } ValueNodePtr GetSwitchAnfPrim() { auto switch_prim = std::make_shared(); + MS_CHECK_TRUE_MSG(switch_prim != nullptr, nullptr, "switch_prim is nullptr"); ValueNodePtr switch_anf_prim = NewValueNode(switch_prim); + MS_CHECK_TRUE_MSG(switch_anf_prim != nullptr, nullptr, "switch_anf_prim is nullptr"); return switch_anf_prim; } ValueNodePtr GetCallAnfPrim() { auto call_prim = std::make_shared(); + MS_CHECK_TRUE_MSG(call_prim != nullptr, nullptr, "call_prim is nullptr"); ValueNodePtr call_anf_prim = NewValueNode(call_prim); + MS_CHECK_TRUE_MSG(call_anf_prim != nullptr, nullptr, "call_anf_prim is nullptr"); return call_anf_prim; } } // namespace lite diff --git a/mindspore/lite/tools/common/parse_config_utils.cc b/mindspore/lite/tools/common/parse_config_utils.cc index 76aa730d86e..56230f7515a 100644 --- a/mindspore/lite/tools/common/parse_config_utils.cc +++ b/mindspore/lite/tools/common/parse_config_utils.cc @@ -19,6 +19,7 @@ #include "mindspore/lite/tools/common/string_util.h" #include "src/common/log_adapter.h" #include "include/errorcode.h" + namespace mindspore { namespace lite { int ReadFileToIfstream(const std::string &file_path, std::ifstream *ifstream) { diff --git a/mindspore/lite/tools/common/tensor_util.cc b/mindspore/lite/tools/common/tensor_util.cc index 99e2a8ac236..ef844107706 100644 --- a/mindspore/lite/tools/common/tensor_util.cc +++ b/mindspore/lite/tools/common/tensor_util.cc @@ -18,6 +18,7 @@ #include "src/common/utils.h" #include "tools/common/graph_util.h" #include "abstract/utils.h" +#include "nnacl/op_base.h" namespace mindspore::lite { std::unique_ptr GetTensorQuantParam(const std::unique_ptr &tensor) { @@ -131,6 +132,7 @@ std::unique_ptr CreateTensorTFromTensorInfo(const tensor::Tenso return nullptr; } auto schema_tensor = std::make_unique(); + MS_CHECK_TRUE_MSG(schema_tensor != nullptr,
nullptr, "schema_tensor is nullptr"); schema_tensor->name = tensor_name; auto ret = UpdateTensorTFromTensorInfo(tensor_info, &schema_tensor); if (ret != RET_OK) { diff --git a/mindspore/lite/tools/optimizer/graph/add_tensor_array.cc b/mindspore/lite/tools/optimizer/graph/add_tensor_array.cc index 62249a258d3..47f0c1b8f0b 100644 --- a/mindspore/lite/tools/optimizer/graph/add_tensor_array.cc +++ b/mindspore/lite/tools/optimizer/graph/add_tensor_array.cc @@ -37,6 +37,7 @@ static bool IsSupportedNode(const BaseRef &n) { }; if (utils::isa(n)) { auto anf_node = utils::cast(n); + MS_ASSERT(anf_node != nullptr); return std::any_of(support_list.begin(), support_list.end(), [&anf_node](const auto &primitive) { return CheckPrimitiveType(anf_node, primitive); }); } @@ -95,6 +96,7 @@ static int SetGraphOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &tens } auto new_return_node = func_graph->NewCNode({NewValueNode(return_prim_ptr), make_tuple_cnode}); + MS_ASSERT(new_return_node != nullptr); new_return_node->set_fullname_with_scope(return_cnode->fullname_with_scope()); func_graph->set_return(new_return_node); MS_ASSERT(new_return_node != nullptr); @@ -103,7 +105,9 @@ const BaseRef AddTensorArray::DefinePattern() const { auto support_detect = std::make_shared(IsSupportedNode); + MS_ASSERT(support_detect != nullptr); auto inputs_var = std::make_shared(); + MS_ASSERT(inputs_var != nullptr); return VectorRef({support_detect, inputs_var}); } @@ -137,6 +141,7 @@ const AnfNodePtr AddTensorArray::Process(const FuncGraphPtr &func_graph, const A } auto abstract_tensor = utils::cast(abstract); + MS_ASSERT(abstract_tensor != nullptr); if (!utils::isa(abstract_tensor->GetValueTrack())) { // input node not complete infershape MS_LOG(DEBUG) << "Value of abstract is not tensor::Tensor, indicate that infershape has failed"; return nullptr; } @@ -168,6 +173,7 @@ const AnfNodePtr AddTensorArray::Process(const
FuncGraphPtr &func_graph, const A NewValueNode(tensor_array), NewValueNode(kDefaultNumTensors), }); + MS_ASSERT(tensor_array_node != nullptr); tensor_array_node->set_abstract(abstract->Clone()); tensor_array_node->set_fullname_with_scope(cnode->fullname_with_scope() + "_tensor_array"); diff --git a/mindspore/lite/tools/optimizer/graph/control_flow_pass.cc b/mindspore/lite/tools/optimizer/graph/control_flow_pass.cc index eed58ae08c5..1972abbd2d9 100644 --- a/mindspore/lite/tools/optimizer/graph/control_flow_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/control_flow_pass.cc @@ -168,6 +168,7 @@ int ControlFlowPass::SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow } visited_nodes->insert(node); auto cnode = utils::cast(node); + MS_CHECK_TRUE_MSG(cnode != nullptr, RET_ERROR, "cast ptr failed"); for (size_t i = 0; i < cnode->inputs().size(); i++) { auto input = cnode->input(i); if (visited_nodes->find(input) == visited_nodes->end()) { @@ -191,6 +192,7 @@ int ControlFlowPass::SplitGraph(const FuncGraphPtr &fg, AnfNodePtr *control_flow int ControlFlowPass::CreateAfterGraph(const FuncGraphPtr &main_fg, const std::vector &remain_nodes, const CNodePtr &aim_cnode, FuncGraphPtr *after_fg) { *after_fg = std::make_shared(); + MS_CHECK_TRUE_MSG(*after_fg != nullptr, lite::RET_NULL_PTR, "*after_fg is nullptr"); auto manager = main_fg->manager(); manager->AddFuncGraph(*after_fg); (*after_fg)->set_attr("fmk", MakeValue(static_cast(converter::kFmkTypeTf))); @@ -282,6 +284,7 @@ int ControlFlowPass::CreateWhileCondCallNode( int ControlFlowPass::CreateWhileBodyPartialNode(const FuncGraphPtr &cond_fg, const CNodePtr &while_cnode, CNodePtr *body_partial_node) { auto body_vnode = while_cnode->input(kWhileBodyIndex); + MS_CHECK_TRUE_MSG(body_vnode != nullptr, lite::RET_NULL_PTR, "body_vnode is nullptr"); auto body_fg = GetValueNode>(body_vnode); if (body_fg == nullptr) { MS_LOG(ERROR) << "Get value as func_graph failed."; @@ -469,13 +472,14 @@ int 
ControlFlowPass::ProcessWhileOp(const FuncGraphPtr &fg, const std::set switch_node_inputs = {switch_anf_primitive, cond_fg->output(), body_partial_node, after_partial_cnode}; auto switch_cnode = cond_fg->NewCNode(switch_node_inputs); + MS_CHECK_TRUE_MSG(switch_cnode != nullptr, RET_ERROR, "NewCnode failed"); switch_cnode->set_fullname_with_scope("while-Switch-" + cond_fg->get_attr("graph_name")->ToString()); // insert call node @@ -590,6 +594,7 @@ int ControlFlowPass::CreateIfPartialNode(const FuncGraphPtr &fg, const size_t &i after_partial_cnode_inputs.push_back(then_fg->output()); } else { auto then_fg_output = then_fg->output()->cast(); + MS_CHECK_TRUE_MSG(then_fg_output != nullptr, RET_ERROR, "cast ptr failed"); for (size_t i = kCNodeFirstInputIndex; i < then_fg_output->inputs().size(); ++i) { after_partial_cnode_inputs.push_back(then_fg_output->input(i)); } diff --git a/mindspore/lite/tools/optimizer/graph/decrease_transpose_algo.cc b/mindspore/lite/tools/optimizer/graph/decrease_transpose_algo.cc index 1c930855a99..a55f0d7ac43 100644 --- a/mindspore/lite/tools/optimizer/graph/decrease_transpose_algo.cc +++ b/mindspore/lite/tools/optimizer/graph/decrease_transpose_algo.cc @@ -84,6 +84,7 @@ STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNode return lite::RET_ERROR; } auto cur_node_post = cur_node_user.first->cast(); + MS_CHECK_TRUE_MSG(cur_node_post != nullptr, RET_ERROR, "cast ptr failed"); if (middle_nodes->find(cur_node_post) != middle_nodes->end() || out_nodes->find(cur_node_post) != out_nodes->end()) { continue; @@ -138,6 +139,7 @@ bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set< continue; } auto middle_node_prim = GetValueNode(middle_cnode->input(0)); + MS_CHECK_TRUE_MSG(middle_node_prim != nullptr, false, "GetValueNode failed"); if (!transpose_strategy.CanChangeOpAxis(func_graph, middle_cnode)) { return false; } @@ -145,30 +147,30 @@ bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr 
&func_graph, const std::set< return true; } -void ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type, - bool train_flag, FormatTransNodeType trans_type) { +int ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type, + bool train_flag, FormatTransNodeType trans_type) { MS_ASSERT(cnode != nullptr); if (utils::isa(cnode->input(index))) { - return; + return lite::RET_OK; } lite::DataInfo data_info; - int status; + int status = 0; if (utils::isa(cnode->input(index))) { auto input_node = cnode->input(index)->cast(); - MS_ASSERT(input_node != nullptr); + MS_CHECK_TRUE_MSG(input_node != nullptr, lite::RET_ERROR, "input_node is nullptr"); if (!input_node->has_default()) { - return; + return lite::RET_OK; } status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info); } else { status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info); } if (status != lite::RET_OK) { - return; + return lite::RET_ERROR; } if (data_info.shape_.empty() || (data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) { - return; + return lite::RET_OK; } ShapeVector expand_shape(data_info.shape_.begin(), data_info.shape_.end()); if (data_info.shape_.size() == 1) { @@ -180,22 +182,24 @@ void ConvertTensorToNCOrNH(const FuncGraphPtr &func_graph, const CNodePtr &cnode } auto tensor = std::make_shared(static_cast(data_info.data_type_), expand_shape, data_info.data_.data(), data_info.data_.size()); + MS_CHECK_TRUE_MSG(tensor != nullptr, lite::RET_ERROR, "tensor is nullptr"); if (trans_type == kNHWC2NCHW) { (void)TransFilterFormat(tensor, schema::Format_KHWC, schema::Format_KCHW); } else { (void)TransFilterFormat(tensor, schema::Format_KCHW, schema::Format_KHWC); } auto param_node = func_graph->add_parameter(); - MS_CHECK_TRUE_MSG(param_node != nullptr, , "add_parameter failed"); + MS_CHECK_TRUE_MSG(param_node 
!= nullptr, lite::RET_ERROR, "add_parameter failed"); param_node->set_name(cnode->input(index)->fullname_with_scope()); status = lite::InitParameterFromTensorInfo(param_node, tensor); if (status != RET_OK) { MS_LOG(ERROR) << "init parameter from tensor info failed"; - return; + return lite::RET_ERROR; } auto tr = func_graph->manager()->Transact(); tr.SetEdge(cnode, index, param_node); tr.Commit(); + return lite::RET_OK; } } // namespace @@ -273,6 +277,7 @@ STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, const std::vector &perm) { MS_ASSERT(func_graph != nullptr && cnode != nullptr); auto prim_node = cnode->input(0); + MS_CHECK_TRUE_MSG(prim_node != nullptr, lite::RET_ERROR, "prim_node is nullptr"); auto prim = GetValueNode(prim_node); MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "GetValueNode Failed"); auto &specify_nhwc_op_map = GetNHWCOpMap(); @@ -346,7 +351,11 @@ STATUS DecreaseTransposeAlgo::InsertPreTransNode(const FuncGraphPtr &func_graph, return lite::RET_ERROR; } } - ModifyCNodeFormat(cnode, trans_insert_info->pre_); + status = ModifyCNodeFormat(cnode, trans_insert_info->pre_); + if (status != lite::RET_OK) { + MS_LOG(ERROR) << "ModifyCNodeFormat failed."; + return lite::RET_ERROR; + } status = node_infer_shape_.InferShape(cnode); if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) { MS_LOG(ERROR) << "infer shape failed."; @@ -439,14 +448,22 @@ STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_grap continue; } for (size_t i = 1; i < middle_cnode->size(); ++i) { - ConvertTensorToNCOrNH(func_graph, middle_cnode, i, fmk_type_, train_flag_, trans_info.post_); + status = ConvertTensorToNCOrNH(func_graph, middle_cnode, i, fmk_type_, train_flag_, trans_info.post_); + if (status != lite::RET_OK) { + MS_LOG(ERROR) << "ConvertTensorToNCOrNH failed."; + return lite::RET_ERROR; + } } status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode, trans_info.post_); if (status != 
lite::RET_OK) { MS_LOG(ERROR) << "change op attr failed."; return lite::RET_ERROR; } - ModifyCNodeFormat(middle_cnode, trans_info.post_); + status = ModifyCNodeFormat(middle_cnode, trans_info.post_); + if (status != lite::RET_OK) { + MS_LOG(ERROR) << "ModifyCNodeFormat failed."; + return lite::RET_ERROR; + } status = node_infer_shape_.InferShape(middle_cnode); if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) { MS_LOG(ERROR) << "infer shape failed."; @@ -456,7 +473,7 @@ STATUS DecreaseTransposeAlgo::HandleGraphMultiNode(const FuncGraphPtr &func_grap return lite::RET_OK; } -void DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { +int DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { MS_ASSERT(cnode != nullptr && sub_graph != nullptr); auto sub_inputs = sub_graph->get_inputs(); sub_inputs_map_[sub_graph] = sub_inputs; @@ -475,6 +492,7 @@ void DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGr MS_ASSERT(out_cnode != nullptr); MS_ASSERT(trans_cnode != nullptr); auto out_prim = GetValueNode(out_cnode->input(0)); + MS_CHECK_TRUE_MSG(out_prim != nullptr, lite::RET_ERROR, "GetValueNode failed"); if (out_prim->GetAttr(kInferDone) == nullptr || !GetValue(out_prim->GetAttr(kInferDone))) { param_node->abstract()->set_shape(std::make_shared(shape_vec)); } @@ -499,9 +517,10 @@ void DecreaseTransposeAlgo::SetSubGraphInput(const CNodePtr &cnode, const FuncGr } } } + return lite::RET_OK; } -void DecreaseTransposeAlgo::ResetSubGraphInput() { +int DecreaseTransposeAlgo::ResetSubGraphInput() { for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) { auto &sub_graph = iter->first; auto &sub_inputs = iter->second; @@ -509,7 +528,7 @@ void DecreaseTransposeAlgo::ResetSubGraphInput() { MS_ASSERT(manager != nullptr); for (auto &sub_input : sub_inputs) { auto param_node = sub_graph->add_parameter(); - MS_CHECK_TRUE_MSG(param_node != 
nullptr, , "add parameter failed"); + MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "add parameter failed"); param_node->set_abstract(sub_input->abstract()->Clone()); param_node->set_name(sub_input->fullname_with_scope()); manager->Replace(sub_input, param_node); @@ -518,9 +537,10 @@ void DecreaseTransposeAlgo::ResetSubGraphInput() { sub_param_input->set_default_param(nullptr); } } + return lite::RET_OK; } -void DecreaseTransposeAlgo::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { +int DecreaseTransposeAlgo::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { MS_ASSERT(cnode != nullptr && sub_graph != nullptr); auto return_node = sub_graph->get_return(); MS_ASSERT(return_node != nullptr); @@ -548,11 +568,13 @@ void DecreaseTransposeAlgo::SetSubGraphOutput(const CNodePtr &cnode, const FuncG trans_cnode->set_fullname_with_scope(trans_input_name); } return_node->set_inputs(origin_input); + return lite::RET_OK; } -void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { +int DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { MS_ASSERT(cnode != nullptr && sub_graph != nullptr); auto return_node = sub_graph->get_return(); + MS_CHECK_TRUE_MSG(return_node != nullptr, lite::RET_ERROR, "return_node is nullptr"); auto origin_inputs = return_node->inputs(); lite::RemoveIfDepend(return_node); lite::RemoveIfMakeTuple(return_node); @@ -560,7 +582,7 @@ void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const Fun bool infer_done = true; for (size_t i = 1; i < return_node->size(); ++i) { auto abstract_base = GetCNodeInputAbstract(return_node, i); - MS_CHECK_TRUE_MSG(abstract_base != nullptr, , "GetCNodeInputAbstract failed"); + MS_CHECK_TRUE_MSG(abstract_base != nullptr, lite::RET_ERROR, "GetCNodeInputAbstract failed"); abstract_list.emplace_back(abstract_base->Clone()); auto abstract_tensor = 
abstract_base->cast(); MS_ASSERT(abstract_tensor != nullptr); @@ -572,11 +594,12 @@ void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const Fun } if (utils::isa(return_node->input(i))) { auto input_cnode = return_node->input(i)->cast(); + MS_CHECK_TRUE_MSG(input_cnode != nullptr, lite::RET_ERROR, "input_cnode is nullptr"); if (CheckPrimitiveType(input_cnode, prim::kPrimTupleGetItem)) { input_cnode = input_cnode->input(1)->cast(); } auto input_prim = GetValueNode(input_cnode->input(0)); - MS_CHECK_TRUE_MSG(input_prim != nullptr, , "GetValueNode failed"); + MS_CHECK_TRUE_MSG(input_prim != nullptr, lite::RET_ERROR, "GetValueNode failed"); if (input_prim->GetAttr(kInferDone) == nullptr || !GetValue(input_prim->GetAttr(kInferDone))) { infer_done = false; } @@ -592,22 +615,25 @@ void DecreaseTransposeAlgo::SetSubGraphAbstract(const CNodePtr &cnode, const Fun cnode->set_abstract(abstract_list.front()); } auto prim = GetValueNode(cnode->input(0)); - MS_CHECK_TRUE_MSG(prim != nullptr, , "GetValueNode Failed"); + MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_ERROR, "GetValueNode Failed"); prim->AddAttr(kInferDone, MakeValue(infer_done)); + + return lite::RET_OK; } -void DecreaseTransposeAlgo::ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type) { +int DecreaseTransposeAlgo::ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type) { MS_ASSERT(cnode != nullptr); if (pre_trans_type == kNONE) { - return; + return lite::RET_OK; } auto primitive = GetValueNode(cnode->input(0)); - MS_CHECK_TRUE_MSG(primitive != nullptr, , "GetValueNode Failed"); + MS_CHECK_TRUE_MSG(primitive != nullptr, lite::RET_ERROR, "GetValueNode Failed"); if (pre_trans_type == kNHWC2NCHW) { primitive->AddAttr(ops::kFormat, MakeValue(mindspore::NCHW)); } else { primitive->AddAttr(ops::kFormat, MakeValue(mindspore::NHWC)); } + return lite::RET_OK; } bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph) { @@ -618,7 
+644,7 @@ bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &fun return false; } auto node_list = TopoSort(func_graph->get_return()); - int status; + int status = 0; for (auto &node : node_list) { if (!utils::isa(node)) { continue; @@ -634,18 +660,38 @@ bool DecreaseTransposeAlgo::DecreaseTransposeForSingleOp(const FuncGraphPtr &fun lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); return false; } - SetSubGraphInput(cnode, sub_func_graph); + auto ret = SetSubGraphInput(cnode, sub_func_graph); + if (ret != lite::RET_OK) { + MS_LOG(ERROR) << "SetSubGraphInput failed"; + return false; + } (void)DecreaseTransposeForSingleOp(sub_func_graph); - SetSubGraphOutput(cnode, sub_func_graph); + ret = SetSubGraphOutput(cnode, sub_func_graph); + if (ret != lite::RET_OK) { + MS_LOG(ERROR) << "SetSubGraphOutput failed"; + return false; + } sub_func_graph = GetValueNode(cnode->input(kInputIndexTwo)); if (sub_func_graph == nullptr) { lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); return false; } - SetSubGraphInput(cnode, sub_func_graph); + ret = SetSubGraphInput(cnode, sub_func_graph); + if (ret != lite::RET_OK) { + MS_LOG(ERROR) << "SetSubGraphInput failed"; + return false; + } (void)DecreaseTransposeForSingleOp(sub_func_graph); - SetSubGraphOutput(cnode, sub_func_graph); - SetSubGraphAbstract(cnode, sub_func_graph); + ret = SetSubGraphOutput(cnode, sub_func_graph); + if (ret != lite::RET_OK) { + MS_LOG(ERROR) << "SetSubGraphOutput failed"; + return false; + } + ret = SetSubGraphAbstract(cnode, sub_func_graph); + if (ret != lite::RET_OK) { + MS_LOG(ERROR) << "SetSubGraphAbstract failed"; + return false; + } continue; } auto prim = GetValueNode(cnode->input(0)); @@ -726,7 +772,12 @@ bool DecreaseTransposeAlgo::Run(const FuncGraphPtr &func_graph) { MS_LOG(ERROR) << "run local trans insert optimizer failed."; return false; } - ResetSubGraphInput(); + + auto ret = ResetSubGraphInput(); + if (ret != 
lite::RET_OK) { + MS_LOG(ERROR) << "ResetSubGraphInput failed."; + return false; + } // if input format of several ops surrounded only by transpose op all can be NHWC, // we can delete these transpose ops, and at the same time, transform these middle ops. if (!DecreaseTransposeForMultiOp(func_graph)) { diff --git a/mindspore/lite/tools/optimizer/graph/decrease_transpose_algo.h b/mindspore/lite/tools/optimizer/graph/decrease_transpose_algo.h index a272f74f4e2..6caeb90e4c7 100644 --- a/mindspore/lite/tools/optimizer/graph/decrease_transpose_algo.h +++ b/mindspore/lite/tools/optimizer/graph/decrease_transpose_algo.h @@ -51,11 +51,11 @@ class DecreaseTransposeAlgo : public Pass { STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::set *visit_transposes); STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info); - void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); - void ResetSubGraphInput(); - void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); - void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); - void ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type); + int SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); + int ResetSubGraphInput(); + int SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); + int SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); + int ModifyCNodeFormat(const CNodePtr &cnode, FormatTransNodeType pre_trans_type); FmkType fmk_type_{converter::kFmkTypeMs}; bool train_flag_{false}; NodeInferShape node_infer_shape_; diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc index de4bb919612..d58e3cf123e 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc +++ 
b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc @@ -111,7 +111,11 @@ bool InferShapePass::Run(const FuncGraphPtr &func_graph) { MS_LOG(ERROR) << "infer shape failed."; return false; } - return ResetSubGraphInput(); + if (ResetSubGraphInput() != lite::RET_OK) { + MS_LOG(ERROR) << "ResetSubGraphInput failed."; + return false; + } + return true; } bool InferShapePass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) { @@ -187,7 +191,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) { MS_LOG(ERROR) << "subgraph infer shape failed."; return false; } - SetSubGraphOutput(cnode, sub_func_graph); + if (SetSubGraphOutput(cnode, sub_func_graph) != lite::RET_OK) { + MS_LOG(ERROR) << "SetSubGraphOutput failed."; + return false; + } sub_func_graph = GetValueNode(cnode->input(kInputIndexTwo)); if (sub_func_graph == nullptr) { lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); @@ -202,7 +209,10 @@ STATUS InferShapePass::InferProcess(const FuncGraphPtr &func_graph) { MS_LOG(ERROR) << "subgraph infer shape failed."; return false; } - SetSubGraphOutput(cnode, sub_func_graph); + if (SetSubGraphOutput(cnode, sub_func_graph) != lite::RET_OK) { + MS_LOG(ERROR) << "SetSubGraphOutput failed."; + return false; + } ret = SetSubGraphAbstract(cnode, sub_func_graph); if (ret != RET_OK) { MS_LOG(ERROR) << "SetSubGraphAbstract failed: " << ret; @@ -278,7 +288,7 @@ STATUS InferShapePass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPt return RET_OK; } -void InferShapePass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { +STATUS InferShapePass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { MS_ASSERT(cnode != nullptr && sub_graph != nullptr); auto return_node = sub_graph->get_return(); MS_ASSERT(return_node != nullptr); @@ -307,6 +317,7 @@ void InferShapePass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr trans_cnode->set_fullname_with_scope(trans_input_name); }
return_node->set_inputs(origin_input); + return lite::RET_OK; } STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { @@ -360,7 +371,7 @@ STATUS InferShapePass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGrap return RET_OK; } -bool InferShapePass::ResetSubGraphInput() { +int InferShapePass::ResetSubGraphInput() { for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) { auto &sub_graph = iter->first; auto &sub_inputs = iter->second; @@ -368,7 +379,7 @@ bool InferShapePass::ResetSubGraphInput() { MS_ASSERT(manager != nullptr); for (auto &sub_input : sub_inputs) { auto param_node = sub_graph->add_parameter(); - MS_CHECK_TRUE_MSG(param_node != nullptr, false, "Add parameter Failed"); + MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "Add parameter Failed"); param_node->set_abstract(sub_input->abstract()->Clone()); param_node->set_name(sub_input->fullname_with_scope()); manager->Replace(sub_input, param_node); @@ -377,7 +388,7 @@ bool InferShapePass::ResetSubGraphInput() { sub_param_input->set_default_param(nullptr); } } - return true; + return lite::RET_OK; } } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.h b/mindspore/lite/tools/optimizer/graph/infershape_pass.h index 5ae4632d046..bac4bd5addd 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.h +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.h @@ -36,9 +36,9 @@ class InferShapePass : public Pass { bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph); STATUS InferProcess(const FuncGraphPtr &func_graph); STATUS SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); - void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); + STATUS SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); STATUS SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); - bool 
ResetSubGraphInput(); + int ResetSubGraphInput(); FmkType fmk_type_{converter::kFmkTypeMs}; bool train_flag_{false}; diff --git a/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc b/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc index 543d346f520..27a48a21bcd 100644 --- a/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc +++ b/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc @@ -90,16 +90,16 @@ std::vector TransformOpAxesAttr(const std::vector &origin_axes, Format return cur_axes; } -void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, - const std::vector &axes, FormatTransNodeType trans_type, - NodeInferShape *node_infer_shape) { +int TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index, + const std::vector &axes, FormatTransNodeType trans_type, + NodeInferShape *node_infer_shape) { MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr); if (input_index >= cnode->size() || axes.empty()) { - return; + return lite::RET_ERROR; } auto origin_input = node_infer_shape->GetIntVecInput(cnode, input_index); if (origin_input.size() != axes.size()) { - return; + return lite::RET_ERROR; } std::vector cur_input; for (int dim = 0; dim < static_cast(kInputSizeFour); ++dim) { @@ -119,7 +119,9 @@ void TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, } } auto param_node = BuildIntVecParameterNode(func_graph, cur_input, cnode->input(input_index)->fullname_with_scope()); + MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "BuildIntVecParameterNode failed"); func_graph->manager()->Replace(cnode->input(input_index), param_node); + return lite::RET_OK; } STATUS ChangeCommonOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type, @@ -266,6 +268,7 @@ STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, Form } int element_num = shape.front(); auto 
prim = GetValueNode>(cnode->input(0)); + MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "GetValueNode failed"); std::vector axes; if (prim->GetAttr(ops::kAxes) == nullptr || prim->get_axes().empty()) { for (int index = 0; index < element_num; ++index) { @@ -314,6 +317,7 @@ STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode auto cur_axes = TransformOpAxesAttr(axes, trans_type); auto param_node = BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope()); + MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "BuildIntVecParameterNode failed"); func_graph->manager()->Replace(cnode->input(kInputIndexFour), param_node); return lite::RET_OK; } diff --git a/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc index 8d3e7a1058b..10e847f801c 100644 --- a/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/unused_transpose_node_remove_pass.cc @@ -72,6 +72,7 @@ bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) { } if (CheckPrimitiveType(node, prim::kPrimTranspose)) { auto transpose_cnode = node->cast(); + MS_ASSERT(transpose_cnode != nullptr); if (!CheckPrimitiveType(transpose_cnode->input(kTransposeInput), prim::kPrimConv2DFusion)) { continue; } @@ -90,6 +91,7 @@ bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) { continue; } auto transpose_cnode = conv_node->input(kTransposeInput)->cast(); + MS_ASSERT(transpose_cnode != nullptr); auto perm = GetTransposePerm(transpose_cnode); if (perm == kPermNHWC) { manager->Replace(transpose_cnode, transpose_cnode->input(1));