!25923 Move IsRealKernel and IsRealCNodeKernel from AnfAlgo to AnfUtils

Merge pull request !25923 from DeshiChen/1101_anfalgo
This commit is contained in:
i-robot 2021-11-10 01:56:07 +00:00 committed by Gitee
commit d4c8dde6d2
77 changed files with 204 additions and 204 deletions

View File

@ -262,7 +262,7 @@ class CNodeDecoder {
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
kernel_info->set_feature_map_flag(true);
}
if (AnfAlgo::IsRealCNodeKernel(cnode_)) {
if (AnfUtils::IsRealCNodeKernel(cnode_)) {
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode_);
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode_);
}

View File

@ -653,7 +653,7 @@ void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr>
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
for (auto const &node : node_lists) {
if (!AnfAlgo::IsRealKernel(node) || !node->isa<CNode>()) {
if (!AnfUtils::IsRealKernel(node) || !node->isa<CNode>()) {
continue;
}
auto cnode = node->cast<CNodePtr>();

View File

@ -369,7 +369,7 @@ void TbeKernelCompileManager::GetAllTbeNodes(const std::shared_ptr<session::Kern
auto all_nodes = kernel_graph->execution_order();
for (const auto &anf_node : all_nodes) {
MS_EXCEPTION_IF_NULL(anf_node);
if (!AnfAlgo::IsRealKernel(anf_node)) {
if (!AnfUtils::IsRealKernel(anf_node)) {
continue;
}
KernelType kernel_type = AnfAlgo::GetKernelType(anf_node);

View File

@ -42,7 +42,7 @@ void BatchMatmulFusedMulAddFusionPass::MatchSingleFusionPattern(const session::K
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -57,7 +57,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -46,7 +46,7 @@ void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGr
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -40,7 +40,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
MS_EXCEPTION_IF_NULL(input_cnode);
auto double_in_eltwise_input = input_cnode->input(kIndex2);
MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input)) {
if (!double_in_eltwise_input->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(double_in_eltwise_input)) {
return;
}
if (AnfAlgo::CheckPrimitiveType(double_in_eltwise_input, prim::kPrimConv2DBackpropInput) &&
@ -51,7 +51,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
} else {
auto double_in_eltwise_input_1 = input_cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(double_in_eltwise_input_1);
if (!double_in_eltwise_input_1->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input_1)) {
if (!double_in_eltwise_input_1->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(double_in_eltwise_input_1)) {
return;
}
if (AnfAlgo::CheckPrimitiveType(double_in_eltwise_input_1, prim::kPrimConv2DBackpropInput) &&
@ -68,7 +68,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const sess
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -30,7 +30,7 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
if (!eltwise_input->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {
return;
}
@ -46,7 +46,7 @@ void Conv2DBackpropEltwiseFusionPass::MatchSingleFusionPattern(const session::Ke
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -43,7 +43,7 @@ void ConvBnReduceFusionPass::MatchSingleFusionPattern(const session::KernelGraph
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -38,7 +38,7 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
MS_EXCEPTION_IF_NULL(input_cnode);
auto double_in_eltwise_input = input_cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) ||
if (!double_in_eltwise_input->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(double_in_eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) {
return;
}
@ -55,7 +55,7 @@ void ConvDoubleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -42,7 +42,7 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con
}
}
MS_EXCEPTION_IF_NULL(eltwise_input);
if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
if (!eltwise_input->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {
return;
}
@ -59,7 +59,7 @@ void ConvSingleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -54,7 +54,7 @@ void DepthwiseConvEltwiseFusionPass::MatchSingleFusionPattern(const session::Ker
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -55,7 +55,7 @@ void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &ker
std::reverse(node_list.begin(), node_list.end());
for (auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -25,7 +25,7 @@ bool FusionBasePass::CheckEltWiseNode(const session::KernelGraph &kernel_graph,
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
if (!node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
return false;
}
auto cnode = node->cast<CNodePtr>();
@ -40,7 +40,7 @@ bool FusionBasePass::CheckDoubleInEltWiseNode(const session::KernelGraph &kernel
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
if (!node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
return false;
}
auto cnode = node->cast<CNodePtr>();
@ -55,7 +55,7 @@ bool FusionBasePass::CheckMultiOutputEltWiseNode(const session::KernelGraph &ker
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
if (!node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
return false;
}
auto cnode = node->cast<CNodePtr>();

View File

@ -42,7 +42,7 @@ void MatmulConfusionTranposeFusionPass::MatchSingleFusionPattern(const session::
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -41,7 +41,7 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -59,7 +59,7 @@ void MultiOutputFusionPass::MatchSingleFusionPattern(const session::KernelGraph
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
std::reverse(node_list.begin(), node_list.end());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -40,7 +40,7 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
}
}
MS_EXCEPTION_IF_NULL(eltwise_input);
if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
if (!eltwise_input->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {
return;
}
@ -72,7 +72,7 @@ void ReduceEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
std::reverse(node_list.begin(), node_list.end());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -38,7 +38,7 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const
}
}
MS_EXCEPTION_IF_NULL(eltwise_input);
if (!eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) ||
if (!eltwise_input->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(eltwise_input) ||
fusion_id_allocator->HasFusionIdAttr(eltwise_input)) {
return;
}
@ -69,7 +69,7 @@ void SegmentEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGra
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
std::reverse(node_list.begin(), node_list.end());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -36,7 +36,7 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
write_input = input_cnode->input(kIndex1);
}
MS_EXCEPTION_IF_NULL(write_input);
if (!write_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(write_input) ||
if (!write_input->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(write_input) ||
fusion_id_allocator->HasFusionIdAttr(write_input)) {
return;
}
@ -49,7 +49,7 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
(void)record.insert(write_input);
auto conv_input = conv_cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(conv_input);
if (!conv_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(conv_input) ||
if (!conv_input->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(conv_input) ||
fusion_id_allocator->HasFusionIdAttr(conv_input)) {
return;
}
@ -66,7 +66,7 @@ void StridedReadConvStridedWriteFusionPass::MatchSingleFusionPattern(const sessi
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}

View File

@ -210,7 +210,7 @@ void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph,
continue;
}
auto cnode = node->cast<CNodePtr>();
if (AnfAlgo::IsRealCNodeKernel(cnode) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) {
if (AnfUtils::IsRealCNodeKernel(cnode) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) {
auto fusion_id = AnfAlgo::GetNodeAttr<int64_t>(cnode, kOpAttrFusionId);
(*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(cnode);
}

View File

@ -28,7 +28,7 @@ const AnfNodePtr AddIoFormatAttrFor3DGraph::Process(const FuncGraphPtr &func_gra
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(func_graph);
if (AnfAlgo::IsRealKernel(node)) {
if (AnfUtils::IsRealKernel(node)) {
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
auto formats = AnfAlgo::GetAllOutputFormats(node);
if (std::any_of(formats.begin(), formats.end(),

View File

@ -27,7 +27,7 @@ bool InsertDependForAllGather::Run(const FuncGraphPtr &graph) {
std::map<int64_t, AnfNodePtr> all_gather_node;
for (auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (!node->cast<CNodePtr>() || !AnfAlgo::IsRealKernel(node)) {
if (!node->cast<CNodePtr>() || !AnfUtils::IsRealKernel(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();

View File

@ -62,7 +62,7 @@ const BaseRef InsertTensorMoveForGetNext::DefinePattern() const {
const AnfNodePtr InsertTensorMoveForGetNext::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (func_graph == nullptr || node == nullptr || !AnfAlgo::IsRealKernel(node)) {
if (func_graph == nullptr || node == nullptr || !AnfUtils::IsRealKernel(node)) {
return nullptr;
}

View File

@ -57,7 +57,7 @@ bool IsNodeOutPutUsedByOtherRealKernel(const FuncGraphPtr &graph, const AnfNodeP
for (const auto &node_pair : user_items) {
auto node = node_pair.first;
MS_EXCEPTION_IF_NULL(node);
if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::IsCommunicationOp(node)) {
if (AnfUtils::IsRealKernel(node) && !AnfAlgo::IsCommunicationOp(node)) {
MS_LOG(INFO) << "This node only used other real kernel: " << node->fullname_with_scope();
return true;
}

View File

@ -137,7 +137,7 @@ bool OutputCheck(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_LOG(INFO) << "Split has control edge, can not optimizer.";
return false;
}
if (AnfAlgo::IsRealKernel(item) && (AnfAlgo::GetProcessor(item) != 0)) {
if (AnfUtils::IsRealKernel(item) && (AnfAlgo::GetProcessor(item) != 0)) {
MS_LOG(INFO) << "Next node is not a AICore node, can not optimizer.";
return false;
}
@ -199,7 +199,7 @@ const AnfNodePtr SplitOpOptimizer::Process(const FuncGraphPtr &func_graph, const
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(func_graph);
if (!AnfAlgo::IsRealCNodeKernel(node)) {
if (!AnfUtils::IsRealCNodeKernel(node)) {
return nullptr;
}
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);

View File

@ -100,7 +100,7 @@ const BaseRef ChangeAxisOfReduceKernel::DefinePattern() const {
const AnfNodePtr ChangeAxisOfReduceKernel::Process(const FuncGraphPtr &, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
if (node == nullptr || !node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
return nullptr;
}
if (AnfAlgo::GetOpPattern(node) != kernel::kReducePattern) {

View File

@ -86,7 +86,7 @@ const BaseRef CheckConsistency::DefinePattern() const {
}
const AnfNodePtr CheckConsistency::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
if (node == nullptr || !node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
return nullptr;
}

View File

@ -33,7 +33,7 @@ const BaseRef ConvertCastFormat::DefinePattern() const {
const AnfNodePtr ConvertCastFormat::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(node)) {
if (node == nullptr || !node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(node)) {
return nullptr;
}
auto node_name = AnfAlgo::GetCNodeName(node);

View File

@ -247,7 +247,7 @@ const AnfNodePtr DealRefAndSpiltUnSupportedTransdata::Process(const FuncGraphPtr
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!AnfAlgo::IsRealCNodeKernel(cnode)) {
if (!AnfUtils::IsRealCNodeKernel(cnode)) {
return nullptr;
}

View File

@ -125,7 +125,7 @@ const BaseRef InsertCast::DefinePattern() const {
const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) {
if (!AnfUtils::IsRealCNodeKernel(node) || func_graph == nullptr) {
return nullptr;
}
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);

View File

@ -33,7 +33,7 @@ const BaseRef InsertTransOp::DefinePattern() const {
const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !AnfAlgo::IsRealKernel(node)) {
if (node == nullptr || !AnfUtils::IsRealKernel(node)) {
return nullptr;
}
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);

View File

@ -29,7 +29,7 @@ bool RunOpInsertTransData::Run(const FuncGraphPtr &graph) {
for (auto &node : node_list) {
bool has_changed = false;
MS_EXCEPTION_IF_NULL(node);
if (!node->cast<CNodePtr>() || !AnfAlgo::IsRealKernel(node)) {
if (!node->cast<CNodePtr>() || !AnfUtils::IsRealKernel(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();

View File

@ -152,7 +152,7 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co
return nullptr;
}
MS_EXCEPTION_IF_NULL(next_node);
if (!next_node->isa<CNode>() || !AnfAlgo::IsRealKernel(next_node)) {
if (!next_node->isa<CNode>() || !AnfUtils::IsRealKernel(next_node)) {
return nullptr;
}
auto next_cnode = next_node->cast<CNodePtr>();
@ -208,7 +208,7 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu
*output_idx = LongToSize(GetValue<int64_t>(value_ptr->value()));
*single_output = false;
}
return AnfAlgo::IsRealKernel(*prior_op);
return AnfUtils::IsRealKernel(*prior_op);
}
return false;
}

View File

@ -149,7 +149,7 @@ void RectifyDoMaskKernelInfo::ReSelecChildNodeKernelInfo(const CNodePtr &cnode,
for (const auto &out_node_info : *output_node_list) {
MS_EXCEPTION_IF_NULL(out_node_info.first);
auto out_node = out_node_info.first->cast<CNodePtr>();
if (AnfAlgo::IsRealKernel(out_node_info.first)) {
if (AnfUtils::IsRealKernel(out_node_info.first)) {
auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node);
kernel_selecter->SelectKernel(out_node);
auto new_build_info = AnfAlgo::GetSelectKernelBuildInfo(out_node);

View File

@ -77,7 +77,7 @@ void AddOutputs(const CNodePtr &cnode, const std::vector<size_t> &input_indices)
const AnfNodePtr AddInputToOutput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
if (node == nullptr || !AnfUtils::IsRealCNodeKernel(node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();

View File

@ -143,7 +143,7 @@ bool InsertCastCPU::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
for (auto node : node_list) {
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
if (node != nullptr && node->isa<CNode>() && AnfUtils::IsRealKernel(node)) {
CNodePtr cnode = node->cast<CNodePtr>();
InsertCast(func_graph, cnode);
}
@ -155,7 +155,7 @@ bool InsertCastCPU::Run(const FuncGraphPtr &func_graph) {
kernel::GetFuncGraphOutputNodes(func_graph, &outputs);
auto func_output = func_graph->output();
for (auto node : outputs) {
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
if (node != nullptr && node->isa<CNode>() && AnfUtils::IsRealKernel(node)) {
auto cnode = node->cast<CNodePtr>();
InsertCastForGraphOutput(func_graph, cnode, func_output);
}

View File

@ -101,7 +101,7 @@ void ProcessForTupleItem(const FuncGraphPtr &graph, const AnfNodePtr &node, int
// node->used_node, if output format of node equals input format of used_node,
// then no need to insert transpose between node and used_node.
auto used_node_in_format =
AnfAlgo::IsRealCNodeKernel(used_node) ? AnfAlgo::GetInputFormat(used_node, used_node_index) : kOpFormat_DEFAULT;
AnfUtils::IsRealCNodeKernel(used_node) ? AnfAlgo::GetInputFormat(used_node, used_node_index) : kOpFormat_DEFAULT;
if (transpose_format == used_node_in_format) {
continue;
}
@ -159,8 +159,9 @@ void InsertTransformOpForOutput(const FuncGraphPtr &graph, const AnfNodePtr &nod
}
// node->used_node, if output format of node equals input format of used_node,
// then no need to insert transpose between node and used_node.
auto used_node_in_format =
AnfAlgo::IsRealCNodeKernel(used_node) ? AnfAlgo::GetInputFormat(used_node, used_node_index) : kOpFormat_DEFAULT;
auto used_node_in_format = AnfUtils::IsRealCNodeKernel(used_node)
? AnfAlgo::GetInputFormat(used_node, used_node_index)
: kOpFormat_DEFAULT;
if (outputs_format[i] == used_node_in_format) {
continue;
}
@ -169,7 +170,6 @@ void InsertTransformOpForOutput(const FuncGraphPtr &graph, const AnfNodePtr &nod
}
}
}
} // namespace
const std::unordered_set<std::string> kChannelLastKernel = {prim::kPrimBiasAdd->name()};
@ -181,7 +181,7 @@ bool InsertFormatTransformOpCPU::Run(const FuncGraphPtr &graph) {
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
for (auto node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node)) {
if (!AnfUtils::IsRealCNodeKernel(node)) {
continue;
}

View File

@ -35,7 +35,7 @@ bool AdjustDependForParallelOptimizerRecomputeAllGatherFusion::Run(const FuncGra
int64_t recompute_min_fusion_id = 0;
for (auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (!node->cast<CNodePtr>() || !AnfAlgo::IsRealKernel(node)) {
if (!node->cast<CNodePtr>() || !AnfUtils::IsRealKernel(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();

View File

@ -68,7 +68,7 @@ bool InsertCastGPU::Run(const FuncGraphPtr &graph) {
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
bool IsCasted = false;
for (auto node : node_list) {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
if (node == nullptr || !node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
continue;
}

View File

@ -122,7 +122,7 @@ const AnfNodePtr InsertFormatTransformOp::Process(const FuncGraphPtr &graph, con
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
if (!AnfAlgo::IsRealCNodeKernel(node)) {
if (!AnfUtils::IsRealCNodeKernel(node)) {
return nullptr;
}
auto iter = device::gpu::kKernelFormatPositionMap.find(AnfAlgo::GetCNodeName(node));

View File

@ -72,7 +72,7 @@ bool ReducePrecisionFusion::Run(const FuncGraphPtr &graph) {
std::vector<AnfNodePtr> node_list = TopoSort(graph->get_return());
for (auto node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (node != nullptr && node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
if (node != nullptr && node->isa<CNode>() && AnfUtils::IsRealKernel(node)) {
size_t input_num = AnfAlgo::GetInputTensorNum(node);
size_t output_num = AnfAlgo::GetOutputTensorNum(node);
for (size_t i = 0; i < input_num; i++) {

View File

@ -219,7 +219,7 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(mng);
for (const auto &n : todos) {
auto node = n->cast<CNodePtr>();
if (node == nullptr || AnfAlgo::IsGraphKernel(node) || IsKeepBasicNode(node) || !AnfAlgo::IsRealKernel(node) ||
if (node == nullptr || AnfAlgo::IsGraphKernel(node) || IsKeepBasicNode(node) || !AnfUtils::IsRealKernel(node) ||
!CanExpand(node)) {
continue;
}

View File

@ -530,7 +530,7 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
name << prefix << "_";
}
for (const auto &node : cnodes) {
if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
if (node->isa<CNode>() && AnfUtils::IsRealKernel(node)) {
name << AnfAlgo::GetCNodeName(node) << "_";
}
}
@ -655,7 +655,7 @@ CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &
if (inputs.size() == 1) {
kernel_info->set_feature_map_flag(true);
}
if (AnfAlgo::IsRealKernel(cnode)) {
if (AnfUtils::IsRealKernel(cnode)) {
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
SetNodeAttrSafely(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);

View File

@ -574,7 +574,7 @@ class Splitter {
cnode->set_input(i, iter->second);
}
}
if (AnfAlgo::IsRealKernel(node)) {
if (AnfUtils::IsRealKernel(node)) {
inlined_nodes_.emplace_back(node);
}
}
@ -807,7 +807,7 @@ class CostModelSplitSchemer : public SplitSchemer {
virtual bool IsValidKernelNode(const AnfNodePtr &node) const {
if (!node->isa<CNode>()) return false;
if (AnfAlgo::IsRealKernel(node)) return true;
if (AnfUtils::IsRealKernel(node)) return true;
return false;
}

View File

@ -64,7 +64,7 @@ void RewriteOutputShape::Process(const AnfNodePtr &node, size_t index, const Abs
}
// do not process from real kernel
if (AnfAlgo::IsRealKernel(node)) return;
if (AnfUtils::IsRealKernel(node)) return;
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() <= 1) return;

View File

@ -426,7 +426,7 @@ void MemReuseUtil::SetGraphOutputRefCount() {
kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
}
MS_EXCEPTION_IF_NULL(kernel_input.first);
if (!kernel_input.first->isa<CNode>() || !AnfAlgo::IsRealKernel(kernel_input.first)) {
if (!kernel_input.first->isa<CNode>() || !AnfUtils::IsRealKernel(kernel_input.first)) {
continue;
}
auto ak_node = kernel_input.first->cast<CNodePtr>();

View File

@ -235,7 +235,7 @@ void MemSwapManager::SaveUserKernelTopoOrder() {
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
for (auto &node_pair : node_set) {
auto user_kernel = node_pair.first;
if (!AnfAlgo::IsRealCNodeKernel(user_kernel)) {
if (!AnfUtils::IsRealCNodeKernel(user_kernel)) {
continue;
}

View File

@ -31,7 +31,7 @@ bool AdjustDependForParallelOptimizerRecomputeAllGather::Run(const FuncGraphPtr
int64_t recompute_min_fusion_id = 0;
for (auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (!node->cast<CNodePtr>() || !AnfAlgo::IsRealKernel(node)) {
if (!node->cast<CNodePtr>() || !AnfUtils::IsRealKernel(node)) {
continue;
}
auto cnode = node->cast<CNodePtr>();

View File

@ -22,7 +22,7 @@ namespace mindspore {
namespace opt {
const AnfNodePtr ConvertAttrToUnifyMindIR::Process(const FuncGraphPtr &, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
if (node == nullptr || !AnfUtils::IsRealCNodeKernel(node)) {
return nullptr;
}

View File

@ -26,7 +26,7 @@ namespace mindspore {
namespace opt {
const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
if (node == nullptr || !AnfUtils::IsRealCNodeKernel(node)) {
return nullptr;
}

View File

@ -93,7 +93,7 @@ const BaseRef ConvertTupleInputToDynamicInput::DefinePattern() const {
const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
if (node == nullptr || !node->isa<CNode>() || !AnfUtils::IsRealKernel(node)) {
return nullptr;
}
ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());

View File

@ -77,7 +77,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func
bool cnode_input_changed = false;
for (size_t i = 0; i < cnode->inputs().size(); ++i) {
const auto &input = cnode->inputs()[i];
if (input->Type() != nullptr && AnfAlgo::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) &&
if (input->Type() != nullptr && AnfUtils::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) &&
!AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) {
cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input));
cnode_input_changed = true;

View File

@ -25,7 +25,7 @@ namespace mindspore {
namespace opt {
const AnfNodePtr CustomOpConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
if (node == nullptr || !AnfUtils::IsRealCNodeKernel(node)) {
return nullptr;
}

View File

@ -35,7 +35,7 @@ CNodePtr GetRealPrevCNode(const AnfNodePtr &node, size_t index, std::vector<Kern
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::IsRealCNodeKernel(cnode)) {
if (AnfUtils::IsRealCNodeKernel(cnode)) {
pass_vector->push_back(make_pair(cnode, IntToSize(1)));
return cnode;
}

View File

@ -526,7 +526,7 @@ void Somas::InitCommonNodeInputs(bool is_all_nop_node, const CNodePtr &kernel) {
MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple";
}
MS_EXCEPTION_IF_NULL(prenode_index.first);
if (!AnfAlgo::IsRealCNodeKernel(prenode_index.first)) {
if (!AnfUtils::IsRealCNodeKernel(prenode_index.first)) {
auto op_name = AnfAlgo::GetCNodeName(kernel);
TypeId input_origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(kernel, i);
if ((op_name == kDynamicRNNOpName || op_name == kDynamicGRUV2OpName) && input_origin_type == kMetaTypeNone) {
@ -814,7 +814,7 @@ void Somas::RefNodeProcess(const session::KernelGraph *graph) {
output_tensor->type_ = kRefNodeOutput;
total_output_size += size;
if (AnfAlgo::IsRealCNodeKernel(origin_pair.first)) {
if (AnfUtils::IsRealCNodeKernel(origin_pair.first)) {
auto ori_node = origin_pair.first->cast<CNodePtr>();
auto ori_index = origin_pair.second;
if (nodes_map_.find(ori_node.get()) == nodes_map_.end()) {

View File

@ -71,7 +71,7 @@ std::unordered_map<AnfNodePtr, NodeInfo> CollectNodeInfo(const FuncGraphPtr &fun
continue;
}
if (!AnfAlgo::IsRealKernel(node)) {
if (!AnfUtils::IsRealKernel(node)) {
res[node] = NodeInfo(NodeType::kSupport, i);
continue;
}

View File

@ -53,7 +53,7 @@ bool TrtConverterContext::Parser() {
// Transform AnfNode To Trt layer.
// Bypass control node including Depend, Load, UpdateState, TupleGetItem, MakeTuple.
std::string op_name = AnfAlgo::GetCNodePrimitive(node)->name();
if (!AnfAlgo::IsRealKernel(node) && op_name != "Return") {
if (!AnfUtils::IsRealKernel(node) && op_name != "Return") {
continue;
}

View File

@ -57,31 +57,6 @@ constexpr size_t kMakeTupleInputStartPos = 1;
const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad};
bool IsOneOfPrimitive(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
return (prim && prim_set.find(prim) != prim_set.end());
}
bool IsRealKernelCNode(const CNodePtr &cnode) {
#ifndef ENABLE_SECURITY
static const PrimitiveSet virtual_prims = {
prim::kPrimImageSummary, prim::kPrimScalarSummary, prim::kPrimTensorSummary, prim::kPrimHistogramSummary,
prim::kPrimMakeTuple, prim::kPrimStateSetItem, prim::kPrimTupleGetItem, prim::kPrimReturn,
prim::kPrimPartial, prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
#else
static const PrimitiveSet virtual_prims = {prim::kPrimMakeTuple, prim::kPrimStateSetItem, prim::kPrimTupleGetItem,
prim::kPrimReturn, prim::kPrimPartial, prim::kPrimDepend,
prim::kPrimUpdateState, prim::kPrimLoad};
#endif
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << cnode->DebugString();
}
const auto &input = cnode->inputs().at(0);
bool is_virtual_node = IsOneOfPrimitive(input, virtual_prims);
return !is_virtual_node;
}
std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
MS_EXCEPTION_IF_NULL(shape);
std::vector<size_t> shape_size_t;
@ -359,7 +334,7 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) {
return VisitKernelWithReturnType(cnode->input(kUpdateStateStateInput), index, skip_nop_node, return_types);
}
if (AnfAlgo::IsOneOfPrimitiveCNode(cnode, follow_first_input_prims)) {
if (IsOneOfPrimitiveCNode(cnode, follow_first_input_prims)) {
return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, skip_nop_node, return_types);
}
if (opt::IsNopNode(cnode) && skip_nop_node) {
@ -436,7 +411,7 @@ std::vector<KernelWithIndex> AnfRuntimeAlgorithm::GetAllOutputWithIndex(const An
}
size_t outputs_num = 1;
if (IsRealCNodeKernel(node)) {
if (AnfUtils::IsRealCNodeKernel(node)) {
outputs_num = AnfAlgo::GetOutputTensorNum(node);
}
// The output may be the tuple of node, so need visit all the outputs of node.
@ -661,7 +636,7 @@ size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
--input_num;
// Exclude monad inputs for real cnodes.
if (input_num > 0 && IsRealKernelCNode(cnode)) {
if (input_num > 0 && AnfUtils::IsRealKernel(cnode)) {
auto &inputs = cnode->inputs();
// Search monad inputs, backward.
for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
@ -717,7 +692,7 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_
std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealKernel(node)) {
if (!AnfUtils::IsRealKernel(node)) {
MS_LOG(EXCEPTION) << "Not real kernel:"
<< "#node [" << node->DebugString() << "]"
<< " trace: " << trace::DumpSourceLines(node);
@ -732,7 +707,7 @@ std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodeP
std::vector<std::string> AnfRuntimeAlgorithm::GetAllInputFormats(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealKernel(node)) {
if (!AnfUtils::IsRealKernel(node)) {
MS_LOG(EXCEPTION) << "Not real kernel:"
<< "#node [" << node->DebugString() << "]"
<< " trace: " << trace::DumpSourceLines(node);
@ -747,7 +722,7 @@ std::vector<std::string> AnfRuntimeAlgorithm::GetAllInputFormats(const AnfNodePt
std::vector<TypeId> AnfRuntimeAlgorithm::GetAllInputDeviceTypes(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealKernel(node)) {
if (!AnfUtils::IsRealKernel(node)) {
MS_LOG(EXCEPTION) << "Not real kernel:"
<< "#node [" << node->DebugString() << "]"
<< " trace: " << trace::DumpSourceLines(node);
@ -762,7 +737,7 @@ std::vector<TypeId> AnfRuntimeAlgorithm::GetAllInputDeviceTypes(const AnfNodePtr
std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputDeviceTypes(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealKernel(node)) {
if (!AnfUtils::IsRealKernel(node)) {
MS_LOG(EXCEPTION) << "Not real kernel:"
<< "#node [" << node->DebugString() << "]"
<< " trace: " << trace::DumpSourceLines(node);
@ -777,7 +752,7 @@ std::vector<TypeId> AnfRuntimeAlgorithm::GetAllOutputDeviceTypes(const AnfNodePt
std::string AnfRuntimeAlgorithm::GetOriginDataFormat(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealKernel(node)) {
if (!AnfUtils::IsRealKernel(node)) {
MS_LOG(EXCEPTION) << "Not real kernel:"
<< "#node [" << node->DebugString() << "]"
<< " trace: " << trace::DumpSourceLines(node);
@ -798,7 +773,7 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
<< node->DebugString() << "]"
<< " trace: " << trace::DumpSourceLines(node);
}
if (!AnfAlgo::IsRealKernel(node)) {
if (!AnfUtils::IsRealKernel(node)) {
return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
@ -822,7 +797,7 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
<< node->DebugString() << "]"
<< " trace: " << trace::DumpSourceLines(node);
}
if (!IsRealKernel(node)) {
if (!AnfUtils::IsRealKernel(node)) {
return GetPrevNodeOutputFormat(node, input_idx);
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
@ -967,7 +942,7 @@ std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, siz
<< node->DebugString() << "]"
<< " trace: " << trace::DumpSourceLines(node);
}
if (!IsRealKernel(node)) {
if (!AnfUtils::IsRealKernel(node)) {
return GetPrevNodeOutputReshapeType(node, input_idx);
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
@ -987,7 +962,7 @@ std::string AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, si
<< GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]"
<< " trace: " << trace::DumpSourceLines(node);
}
if (!IsRealKernel(node)) {
if (!AnfUtils::IsRealKernel(node)) {
return GetPrevNodeOutputReshapeType(node, output_idx);
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
@ -1041,7 +1016,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
<< GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]"
<< " trace: " << trace::DumpSourceLines(node);
}
if (!IsRealKernel(node)) {
if (!AnfUtils::IsRealKernel(node)) {
return GetPrevNodeOutputDeviceDataType(node, output_idx);
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
@ -1064,7 +1039,7 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
<< GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]"
<< " trace: " << trace::DumpSourceLines(node);
}
if (!IsRealKernel(node)) {
if (!AnfUtils::IsRealKernel(node)) {
return GetPrevNodeOutputDeviceDataType(node, 0);
}
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
@ -1458,38 +1433,10 @@ void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *
kernel_info->set_kernel_mod(kernel_mod);
}
bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
// parameter and value node is a real kernel too
if (!node->isa<CNode>()) {
return true;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString()
<< " trace: " << trace::DumpSourceLines(node);
}
return IsRealKernelCNode(cnode);
}
bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
// parameter and value node is not a real cnode kernel
if (!node->isa<CNode>()) {
return false;
}
// return considered as a real node
if (CheckPrimitiveType(node, prim::kPrimReturn)) {
return true;
}
return IsRealKernel(node);
}
bool AnfRuntimeAlgorithm::IsGraphKernel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
// graph kernel should be a real cnode kernel.
if (!IsRealCNodeKernel(node)) {
if (!AnfUtils::IsRealCNodeKernel(node)) {
return false;
}
@ -2248,9 +2195,8 @@ void AnfRuntimeAlgorithm::GetAllVisitedCNode(const CNodePtr &anf_node, std::vect
if (!input->isa<CNode>()) {
continue;
}
auto input_cnode = input->cast<CNodePtr>();
if (!IsRealKernelCNode(input_cnode) || opt::IsNopNode(input_cnode)) {
GetAllVisitedCNode(input_cnode, used_kernels, visited);
if (!AnfUtils::IsRealKernel(input) || opt::IsNopNode(input)) {
GetAllVisitedCNode(input->cast<CNodePtr>(), used_kernels, visited);
} else {
used_kernels->push_back(input);
}
@ -2267,7 +2213,7 @@ void AnfRuntimeAlgorithm::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::
return;
}
visited->insert(anf_node);
if (AnfAlgo::IsRealKernel(anf_node)) {
if (AnfUtils::IsRealKernel(anf_node)) {
result->emplace_back(anf_node);
return;
}
@ -2395,15 +2341,6 @@ bool AnfRuntimeAlgorithm::IsTensorBroadcast(const std::vector<size_t> &lhs, cons
return false;
}
bool AnfRuntimeAlgorithm::IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr || cnode->size() == 0) {
return false;
}
return IsOneOfPrimitive(cnode->inputs().at(kAnfPrimitiveIndex), prim_set);
}
bool AnfRuntimeAlgorithm::IsControlOpExecInBackend(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return false;

View File

@ -35,11 +35,11 @@
#include "backend/kernel_compiler/kernel_build_info.h"
#include "base/core_ops.h"
#include "utils/contract.h"
#include "utils/anf_utils.h"
#include "backend/session/kernel_graph.h"
namespace mindspore {
namespace session {
using PrimitiveSet = std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual>;
using AnfVisitFuncion = std::function<Any(const AnfNodePtr &node, int index)>;
using DeviceAddress = device::DeviceAddress;
using DeviceAddressPtr = device::DeviceAddressPtr;
@ -236,10 +236,6 @@ class AnfRuntimeAlgorithm {
static kernel::KernelMod *GetKernelMod(const AnfNodePtr &node);
// set kernel mod
static void SetKernelMod(const kernel::KernelModPtr &kernel_mod, AnfNode *node);
// checkout whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too
static bool IsRealKernel(const AnfNodePtr &node);
// checkout whether the anf node is a real kernel that is a cnode and can run on device
static bool IsRealCNodeKernel(const AnfNodePtr &node);
// checkout whether the anf node is a graph kernel.
static bool IsGraphKernel(const AnfNodePtr &node);
// checkout whether the anf node is an inner node of graph kernel.
@ -321,7 +317,6 @@ class AnfRuntimeAlgorithm {
}
return result;
}
static bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set);
// Judge a control operator need be compiled into kernel graph rather than be cut into single op and
// executed in vm. For example, the operator "bprop_cut" will be compiled into kernel graph and be launch

View File

@ -759,7 +759,7 @@ class AscendAutoMonadConverter {
end_iter = std::find(nodes.rbegin(), nodes.rend(), end_node);
}
for (auto iter = nodes.rbegin(); iter != end_iter; ++iter) {
if (!AnfAlgo::IsRealCNodeKernel(*iter)) {
if (!AnfUtils::IsRealCNodeKernel(*iter)) {
continue;
}
if (AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelSet)) {
@ -789,7 +789,7 @@ class AscendAutoMonadConverter {
end_iter = std::find(nodes.rbegin(), nodes.rend(), end_node);
}
for (auto iter = nodes.rbegin(); iter != end_iter; ++iter) {
if (!AnfAlgo::IsRealCNodeKernel(*iter)) {
if (!AnfUtils::IsRealCNodeKernel(*iter)) {
continue;
}
if (AnfAlgo::CheckPrimitiveType(*iter, prim::kPrimLabelGoto) && AnfAlgo::HasNodeAttr(kAttrReturn, *iter)) {

View File

@ -241,7 +241,7 @@ void KernelGraph::SetExecOrderByDefault() {
}
// add execute node
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>() && AnfAlgo::IsRealKernel(node)) {
if (node->isa<CNode>() && AnfUtils::IsRealKernel(node)) {
execution_order_.push_back(node->cast<CNodePtr>());
}
// delay execute comm ops that need optimize
@ -527,7 +527,7 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
if (AnfAlgo::GetInputTensorNum(node) == 0) {
kernel_info->set_feature_map_flag(true);
}
if (AnfAlgo::IsRealKernel(node)) {
if (AnfUtils::IsRealKernel(node)) {
// if the node only has the primitive(such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), node);

View File

@ -75,7 +75,7 @@ bool RecursiveCheck(const FuncGraphManagerPtr &manager, const std::pair<AnfNodeP
(AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad))) {
return false;
}
if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
if (AnfUtils::IsRealKernel(node) && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
return true;
}
(*idx) += 1;
@ -564,7 +564,7 @@ AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, Kernel
auto parameters = AnfAlgo::GetAllOutput(new_parameter);
std::vector<AnfNodePtr> pre_graph_out = {node};
// If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive
if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) {
if (!pre_graph_out.empty() && !AnfUtils::IsRealKernel(node)) {
pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState});
}
@ -1832,7 +1832,7 @@ void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(node);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
MS_EXCEPTION_IF_NULL(item_with_index.first);
if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
if (!AnfUtils::IsRealKernel(item_with_index.first)) {
MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
}
summary[n->fullname_with_scope()] = item_with_index;
@ -1935,7 +1935,7 @@ AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) {
if (!front_node->isa<CNode>()) {
return nullptr;
}
if (AnfAlgo::IsRealKernel(front_node)) {
if (AnfUtils::IsRealKernel(front_node)) {
return front_node;
}
if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
@ -1967,7 +1967,7 @@ bool IsUnusedInternlOutput(const AnfNodePtr &user) {
if (IsPrimitiveCNode(user, prim::kPrimSwitch) || IsPrimitiveCNode(user, prim::kPrimSwitchLayer)) {
return true;
}
if (!AnfAlgo::IsRealKernel(user)) {
if (!AnfUtils::IsRealKernel(user)) {
return true;
}
return false;
@ -2001,7 +2001,7 @@ std::string SessionBasic::AddPartialParametersMap(const AnfNodePtr &partial_node
if (!node->isa<CNode>()) {
continue;
}
if (!AnfAlgo::IsRealKernel(node)) {
if (!AnfUtils::IsRealKernel(node)) {
continue;
}
std::string cur_target = GetCNodeTarget(node);

View File

@ -108,7 +108,7 @@ VectorRef MultiGraphAclSession::ConstructOutputRef(GraphId graph_id, std::deque<
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors));
} else if (AnfAlgo::IsRealKernel(anf_node)) {
} else if (AnfUtils::IsRealKernel(anf_node)) {
if (out_tensors->empty()) {
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << out->DebugString()
<< ", visited: " << anf_node->DebugString();
@ -139,7 +139,7 @@ VectorRef MultiGraphAclSession::ConstructOutputRefByTupleNode(const CNodePtr &tu
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
outs.emplace_back(ConstructOutputRefByTupleNode(cnode, out_tensors));
} else if (AnfAlgo::IsRealKernel(anf_node)) {
} else if (AnfUtils::IsRealKernel(anf_node)) {
if (out_tensors->empty()) {
MS_LOG(EXCEPTION) << "Can not find MSTensor for output node " << tuple_node->input(i)->DebugString()
<< ", visited: " << anf_node->DebugString();

View File

@ -1238,7 +1238,7 @@ bool AscendStreamAssign::IsAllOutGraphOut(const KernelGraphPtr &graph, const CNo
for (const auto &node : nodes) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
MS_EXCEPTION_IF_NULL(item_with_index.first);
if (!item_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(item_with_index.first)) {
if (!item_with_index.first->isa<CNode>() || !AnfUtils::IsRealKernel(item_with_index.first)) {
continue;
}
if (item_with_index.first == cnode) {

View File

@ -71,7 +71,7 @@ static bool KernelBuildParallelCompile(const std::vector<CNodePtr> &kernels) {
std::vector<AnfNodePtr> other_nodes;
for (const auto &anf_node : kernels) {
MS_EXCEPTION_IF_NULL(anf_node);
if (!AnfAlgo::IsRealKernel(anf_node)) {
if (!AnfUtils::IsRealKernel(anf_node)) {
continue;
}
if (AnfAlgo::GetKernelMod(anf_node) != nullptr) {

View File

@ -190,7 +190,7 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i
auto prenode_with_index = AnfAlgo::GetPrevNodeOutput(anf_node_ptr, i);
MS_EXCEPTION_IF_NULL(prenode_with_index.first);
if (AnfAlgo::IsRealCNodeKernel(prenode_with_index.first)) {
if (AnfUtils::IsRealCNodeKernel(prenode_with_index.first)) {
if ((AnfAlgo::GetCNodeName(prenode_with_index.first) == kSplitOpName ||
AnfAlgo::GetCNodeName(prenode_with_index.first) == kSplitVOpName) &&
AnfAlgo::HasNodeAttr(kAttrNonTask, prenode_with_index.first->cast<CNodePtr>())) {

View File

@ -263,7 +263,7 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<Type
if (!FormatTransformChecker::GetInstance().format_transform()) {
return false;
}
if (!AnfAlgo::IsRealCNodeKernel(kernel_node)) {
if (!AnfUtils::IsRealCNodeKernel(kernel_node)) {
return false;
}
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);

View File

@ -633,7 +633,7 @@ void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph &graph)
for (const auto &node : nodes) {
auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
MS_EXCEPTION_IF_NULL(kernel_with_index.first);
if (!kernel_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(kernel_with_index.first)) {
if (!kernel_with_index.first->isa<CNode>() || !AnfUtils::IsRealKernel(kernel_with_index.first)) {
continue;
}
if (AnfAlgo::IsCommunicationOp(kernel_with_index.first)) {

View File

@ -102,7 +102,7 @@ bool IsInternalParameter(const AnfNodePtr &node, const KernelGraphPtr &graph) {
bool IsKernelActor(const AnfNodePtr &node, GraphExecutionStrategy strategy) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealCNodeKernel(node)) {
if (!AnfUtils::IsRealCNodeKernel(node)) {
return false;
}

View File

@ -797,7 +797,7 @@ void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph,
for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
auto input_node = AnfAlgo::GetInputNode(kernel, i);
// Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
if (AnfAlgo::IsOneOfPrimitiveCNode(input_node, auto_monad_prims) || HasAbstractMonad(input_node)) {
if (IsOneOfPrimitiveCNode(input_node, auto_monad_prims) || HasAbstractMonad(input_node)) {
LinkControlArrowByAutoMonad(kernel_actor, input_node, graph);
}
if (HasAbstractMonad(input_node)) {
@ -1069,7 +1069,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(AbstractActor *to_actor, const
real_depend_inputs.push_back(input_cnode->input(kDependAttachNodeIndex));
// The real input may be this scene: depend/load --> load/depend, so need add the control arrow for real input
// node in this scene.
if (AnfAlgo::IsOneOfPrimitiveCNode(input_cnode->input(kRealInputIndexInDepend), recursion_prims)) {
if (IsOneOfPrimitiveCNode(input_cnode->input(kRealInputIndexInDepend), recursion_prims)) {
real_depend_inputs.push_back(input_cnode->input(kRealInputIndexInDepend));
}
} else if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimUpdateState)) {
@ -1101,7 +1101,7 @@ void GraphScheduler::LinkControlArrowByAutoMonad(AbstractActor *to_actor, const
}
// The monad node and make tuple node need recursion.
if (AnfAlgo::IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) {
if (IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) {
LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph);
continue;
}

View File

@ -588,4 +588,18 @@ bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
}
return false;
}
bool IsOneOfPrimitive(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
return (prim && prim_set.find(prim) != prim_set.end());
}
bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr || cnode->size() == 0) {
return false;
}
return IsOneOfPrimitive(cnode->input(0), prim_set);
}
} // namespace mindspore

View File

@ -24,6 +24,7 @@
#include <vector>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <set>
@ -73,6 +74,9 @@ using FuncGraphVector = std::vector<FuncGraphPtr>;
class Primitive;
using PrimitivePtr = std::shared_ptr<Primitive>;
class PrimitiveHasher;
class PrimitiveEqual;
using PrimitiveSet = std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual>;
class BaseRef;
@ -1056,6 +1060,16 @@ PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
/// \return Whether the given node is a ValueNode with some Primitive value.
MS_CORE_API bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value);
// Check whether the given node is a ValueNode belonging to a primitive set.
bool IsOneOfPrimitive(const AnfNodePtr &node, const PrimitiveSet &prim_set);
/// \brief Used to check whether the given node is a CNode belonging to a primitive set.
///
/// \param[in] node The input node.
/// \param[in] prim_set Primitive set.
/// \return Whether the given node is a CNode belonging to a primitive set.
MS_CORE_API bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set);
// Check whether two primitives are same.
bool IsPrimitiveEquals(const PrimitivePtr &prim1, const PrimitivePtr &prim2);

View File

@ -15,6 +15,9 @@
*/
#include "utils/anf_utils.h"
#include "base/core_ops.h"
#include "utils/trace_base.h"
#include "utils/utils.h"
namespace mindspore {
bool AnfUtils::IsDimUnknown(const abstract::ShapePtr &shape) {
@ -79,4 +82,39 @@ bool AnfUtils::IsDimUnknown(const AnfNodePtr &node) {
}
return false;
}
bool AnfUtils::IsRealKernel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
#ifndef ENABLE_SECURITY
static const PrimitiveSet virtual_prims = {
prim::kPrimImageSummary, prim::kPrimScalarSummary, prim::kPrimTensorSummary, prim::kPrimHistogramSummary,
prim::kPrimMakeTuple, prim::kPrimStateSetItem, prim::kPrimTupleGetItem, prim::kPrimReturn,
prim::kPrimPartial, prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
#else
static const PrimitiveSet virtual_prims = {prim::kPrimMakeTuple, prim::kPrimStateSetItem, prim::kPrimTupleGetItem,
prim::kPrimReturn, prim::kPrimPartial, prim::kPrimDepend,
prim::kPrimUpdateState, prim::kPrimLoad};
#endif
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
// parameter and value node is a real kernel too
return true;
}
if (cnode->size() == 0) {
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString()
<< " trace: " << trace::DumpSourceLines(node);
}
return !IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), virtual_prims);
}
bool AnfUtils::IsRealCNodeKernel(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return false;
}
if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
return true;
}
return AnfUtils::IsRealKernel(node);
}
} // namespace mindspore

View File

@ -30,7 +30,10 @@ class AnfUtils {
static bool IsShapeDynamic(const std::vector<size_t> &shape);
static bool IsNodeOutputDynamicShape(const CNodePtr &node);
static bool IsDimUnknown(const AnfNodePtr &node);
// check whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too
static bool IsRealKernel(const AnfNodePtr &node);
// check whether the anf node is a real kernel that is a cnode and can run on device
static bool IsRealCNodeKernel(const AnfNodePtr &node);
};
} // namespace mindspore
#endif // MINDSPORE_CORE_UTILS_ANF_UTILS_H_

View File

@ -70,14 +70,14 @@ TEST_F(TestHWInsertCast, test_insert_cast_op_for_single_output) {
builder1.SetProcessor(kernel::Processor::AICORE);
builder1.SetKernelType(KernelType::AKG_KERNEL);
auto node_list = TopoSort(func_graph->get_return());
for (auto& node : node_list) {
for (auto &node : node_list) {
if (node == nullptr) {
continue;
}
if (node->isa<Parameter>()) {
node->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), node.get());
} else if (node != func_graph->get_return() && AnfAlgo::IsRealKernel(node)) {
} else if (node != func_graph->get_return() && AnfUtils::IsRealKernel(node)) {
node->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
}
@ -124,14 +124,14 @@ TEST_F(TestHWInsertCast, test_insert_cast_op_for_multiple_output) {
builder1.SetProcessor(kernel::Processor::AICORE);
builder1.SetKernelType(KernelType::AKG_KERNEL);
auto node_list = TopoSort(func_graph->get_return());
for (auto& node : node_list) {
for (auto &node : node_list) {
if (node == nullptr) {
continue;
}
if (node->isa<Parameter>()) {
node->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), node.get());
} else if (node != func_graph->get_return() && AnfAlgo::IsRealKernel(node)) {
} else if (node != func_graph->get_return() && AnfUtils::IsRealKernel(node)) {
node->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), node.get());
}

View File

@ -372,8 +372,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceShape) {
MS_EXCEPTION_IF_NULL(d_kernel_info);
KernelBuildInfoBuilder builder;
builder.SetOutputsFormat({kOpFormat_NCHW, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_NZ});
builder.SetOutputsDeviceType({kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id(),
kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id()});
d_kernel_info->set_select_kernel_build_info(builder.Build());
EXPECT_EQ(AnfAlgo::GetOutputDeviceShape(add, 0)[2], 224);
EXPECT_EQ(AnfAlgo::GetOutputDeviceShape(add, 1)[0], 2);
@ -722,42 +721,42 @@ TEST_F(AnfRuntimeAlgorithmTest, IsRealKernel) {
auto kernel_graph = std::make_shared<KernelGraph>();
// test value node as input
auto value_node = NewValueNode(prim::kPrimAdd);
EXPECT_TRUE(AnfAlgo::IsRealKernel(value_node));
EXPECT_THROW(AnfAlgo::IsRealKernel(nullptr), std::runtime_error);
EXPECT_TRUE(AnfUtils::IsRealKernel(value_node));
EXPECT_THROW(AnfUtils::IsRealKernel(nullptr), std::runtime_error);
// test parameter as input
auto parameter_node = kernel_graph->add_parameter();
EXPECT_TRUE(AnfAlgo::IsRealKernel(parameter_node));
EXPECT_TRUE(AnfUtils::IsRealKernel(parameter_node));
// test add as input
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimAdd));
auto add = kernel_graph->NewCNode(inputs);
EXPECT_TRUE(AnfAlgo::IsRealKernel(add));
EXPECT_TRUE(AnfUtils::IsRealKernel(add));
// test Depend as input
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimDepend));
auto depend_node = kernel_graph->NewCNode(inputs);
EXPECT_FALSE(AnfAlgo::IsRealKernel(depend_node));
EXPECT_FALSE(AnfUtils::IsRealKernel(depend_node));
}
TEST_F(AnfRuntimeAlgorithmTest, IsRealCNodeKernel) {
auto kernel_graph = std::make_shared<KernelGraph>();
// test value node as input
auto value_node = NewValueNode(prim::kPrimAdd);
EXPECT_FALSE(AnfAlgo::IsRealCNodeKernel(value_node));
EXPECT_THROW(AnfAlgo::IsRealCNodeKernel(nullptr), std::runtime_error);
EXPECT_FALSE(AnfUtils::IsRealCNodeKernel(value_node));
EXPECT_THROW(AnfUtils::IsRealCNodeKernel(nullptr), std::runtime_error);
// test parameter as input
auto parameter_node = kernel_graph->add_parameter();
EXPECT_FALSE(AnfAlgo::IsRealCNodeKernel(parameter_node));
EXPECT_FALSE(AnfUtils::IsRealCNodeKernel(parameter_node));
// test add as input
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimAdd));
auto add = kernel_graph->NewCNode(inputs);
EXPECT_TRUE(AnfAlgo::IsRealCNodeKernel(add));
EXPECT_TRUE(AnfUtils::IsRealCNodeKernel(add));
// test ImageSummary as input
inputs.clear();
inputs.push_back(NewValueNode(prim::kPrimDepend));
auto depend = kernel_graph->NewCNode(inputs);
EXPECT_FALSE(AnfAlgo::IsRealCNodeKernel(depend));
EXPECT_FALSE(AnfUtils::IsRealCNodeKernel(depend));
}
TEST_F(AnfRuntimeAlgorithmTest, IsParameterWeight) {