From 9808e4766352ccf0ef93ddbd8adf199b91ae7234 Mon Sep 17 00:00:00 2001 From: WilliamLian Date: Tue, 9 Jun 2020 09:31:43 +0800 Subject: [PATCH] change checkAicpu to CheckAICPU & add charge Scalar function to charge the input or output is scalar --- .../ccsrc/pre_activate/ascend/ascend_helper.h | 4 ++-- .../convert_unsupported_transnode_to_aicpu.cc | 8 ++++---- .../pre_activate/ascend/ir_fission/topk_split.cc | 2 +- .../ir_fusion/transpose_transdata_fusion.cc | 2 +- mindspore/ccsrc/session/anf_runtime_algorithm.cc | 16 ++++++++++++++++ mindspore/ccsrc/session/anf_runtime_algorithm.h | 2 ++ mindspore/ccsrc/utils/utils.h | 2 ++ .../ascend/ir_fission/topk_split_test.cc | 3 +-- .../ir_fusion/transpose_transdata_fusion_test.cc | 2 +- 9 files changed, 30 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h index 66e3f2ad330..ee0d837cee4 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h @@ -37,11 +37,11 @@ class SupportedChecker { public: SupportedChecker() = default; virtual ~SupportedChecker() = default; - virtual bool CheckAiCoreSupported(const AnfNodePtr &anf_node, + virtual bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) { return kernel::IsSupportedByAICore(anf_node, select_kernel_build_info); } - virtual bool CheckAiCpuSupported(const AnfNodePtr &anf_node, + virtual bool CheckAICPUSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) { return kernel::IsSupportedByAICPU(anf_node, select_kernel_build_info); } diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc index 5b5bf7e4fcb..cfa4e423424 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/convert_unsupported_transnode_to_aicpu.cc @@ -38,9 +38,9 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph return nullptr; } auto kernel_builder_info = AnfAlgo::GetSelectKernelBuildInfo(node); - if (supported_checker_->CheckAiCoreSupported(node, kernel_builder_info)) { - return node; - } else if (supported_checker_->CheckAiCpuSupported(node, kernel_builder_info)) { + if (supported_checker_->CheckAICoreSupported(node, kernel_builder_info)) { + return nullptr; + } else if (supported_checker_->CheckAICPUSupported(node, kernel_builder_info)) { auto builder = std::make_shared(kernel_builder_info); builder->SetKernelType(AICPU_KERNEL); AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); @@ -49,7 +49,7 @@ const AnfNodePtr ConvertUnSupportNodeToAICPU::Process(const mindspore::FuncGraph MS_LOG(EXCEPTION) << " kernel " << kernel_builder_info->ToString() << "is not supported in AiCPU & AiCore : node [" << node->DebugString() << "]"; } - return node; + return nullptr; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc index 9abef8fa703..95bcb9f210b 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc @@ -148,7 +148,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod auto indices_const = CreateValueNode(new_cnode); new_cnode->add_input(indices_const); MS_EXCEPTION_IF_NULL(supported_checker_); - if (!supported_checker_->CheckAiCoreSupported(new_cnode, CreateKernelBuildInfo())) { + if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) { MS_LOG(INFO) << "split topk failed, check to aicpu."; return nullptr; } diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc index 16517187032..e45fc2637fe 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transpose_transdata_fusion.cc @@ -53,7 +53,7 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap new_transdata_builder->SetProcessor(transdata_kernel_build_info->processor()); auto new_fusion_transdata = std::make_shared(kTransDataOpName); - if (supported_checker_->CheckAiCoreSupported(transdata_cnode, new_transdata_builder->Build())) { + if (supported_checker_->CheckAICoreSupported(transdata_cnode, new_transdata_builder->Build())) { std::vector inputs = {NewValueNode(new_fusion_transdata), utils::cast((*equiv)[input_varptr_])}; auto new_node = func_graph->NewCNode(inputs); diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 6cc68457e5d..09ea32becba 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -976,5 +976,21 @@ bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) { } MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString(); } + +bool AnfRuntimeAlgorithm::IsScalarInput(const CNodePtr &cnode, size_t index) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); + if (shape.empty()) { + return true; + } + return shape.size() == kShape1dDims && shape[0] == 1; +} + +bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) { + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); + if (shape.empty()) { + return true; + } + return shape.size() == kShape1dDims && shape[0] == 1; +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h index 10ae5282e0a..bab867a3ef8 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h @@ -185,6 +185,8 @@ class AnfRuntimeAlgorithm { static FuncGraphPtr GetValueNodeFuncGraph(const AnfNodePtr &node); static std::vector GetCallNodeKernelGraph(const CNodePtr &call_node); static bool IsSwitchCall(const CNodePtr &call_node); + static bool IsScalarInput(const CNodePtr &cnode, size_t index); + static bool IsScalarOutput(const CNodePtr &cnode, size_t index); }; } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index ff2ba05c841..b2771f4b9b7 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -207,7 +207,9 @@ constexpr auto kValueTargetOther = "target_other"; // some size const size_t kShape4dDims = 4; +const size_t kShape2dDims = 2; const size_t kShape5dDims = 5; +const size_t kShape1dDims = 1; const size_t kCubeSize = 16; const size_t kMemAlignSize = 512; const int kParameterDataTensorMask = 0; diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc index 4cee3577ed4..b09268aa662 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc @@ -55,8 +55,7 @@ class MockSupportedChecker : public SupportedChecker { public: MockSupportedChecker() = default; ~MockSupportedChecker() override = default; - bool CheckAiCoreSupported(const AnfNodePtr &anf_node, - const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { + bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { return true; } }; // namespace opt diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc index 8bb9de7c7d4..98dc9e9efc3 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc @@ -42,7 +42,7 @@ class MockSupportedChecker : public SupportedChecker { public: MockSupportedChecker() = default; ~MockSupportedChecker() override = default; - bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { + bool CheckAICoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { return true; } };