diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h index a8fd7dc5144..1840966358a 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h @@ -21,6 +21,7 @@ #include #include "device/ascend/kernel_select_ascend.h" #include "kernel/kernel_query.h" +#include "kernel/tbe/tbe_kernel_select.h" namespace mindspore { namespace opt { @@ -36,6 +37,16 @@ class KernelSelect { }; using KernelSelectPtr = std::shared_ptr; +class SupportedChecker { + public: + SupportedChecker() = default; + virtual ~SupportedChecker() = default; + virtual bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) { + return kernel::CheckSupported(anf_node, select_kernel_build_info); + } +}; +using SupportedCheckerPtr = std::shared_ptr; + class KernelQuery { public: KernelQuery() = default; diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc index 5924f6cd1cb..4bdd5f03825 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.cc @@ -16,6 +16,9 @@ #include "pre_activate/ascend/ir_fission/topk_split.h" #include #include +#include +#include "pre_activate/common/helper.h" +#include "kernel/kernel_build_info.h" #include "utils/utils.h" #include "session/kernel_graph.h" #include "session/anf_runtime_algorithm.h" @@ -25,6 +28,7 @@ namespace mindspore { namespace opt { constexpr size_t kFloat16Len = 2; // size of float16; +constexpr size_t kTopkIndexK = 1; namespace { tensor::TensorPtr CreateTensor(const AnfNodePtr &node) { // 1 create tensor @@ -70,37 +74,68 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) { AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), indices_const.get()); return indices_const; } + +kernel::KernelBuildInfoPtr 
CreateKernelBuildInfo() { + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); + builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); + builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16}); + builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32}); + return builder.Build(); +} } // namespace const BaseRef TopKSplit::DefinePattern() const { - VarPtr X = std::make_shared<Var>(); - MS_EXCEPTION_IF_NULL(X); + VarPtr X1 = std::make_shared<Var>(); + VarPtr X2 = std::make_shared<Var>(); auto prim = std::make_shared<Primitive>(kTopKOpName); - MS_EXCEPTION_IF_NULL(prim); - return VectorRef({prim, X}); + return VectorRef({prim, X1, X2}); } const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); auto kernel_graph = func_graph->cast<KernelGraphPtr>(); - auto indices_const = CreateValueNode(node); // set value node as topk's input auto cnode = node->cast<CNodePtr>(); MS_EXCEPTION_IF_NULL(cnode); - MS_LOG(INFO) << "already has input size: " << cnode->inputs().size(); - cnode->add_input(indices_const); + // Copy a new node to check supported.
+ std::vector new_inputs{NewValueNode(std::make_shared(kTopKOpName))}; + new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); + CNodePtr new_cnode = func_graph->NewCNode(new_inputs); + MS_EXCEPTION_IF_NULL(new_cnode); + new_cnode->set_abstract(cnode->abstract()); + new_cnode->set_scope(cnode->scope()); + AnfAlgo::CopyNodeAttrs(cnode, new_cnode); + CheckCNodeInputSize(new_cnode, kTopkInputNum); + // Convert the tensor input to scalar and convert it to attr + auto input_k = new_cnode->input(kTopkIndexK + 1); + MS_EXCEPTION_IF_NULL(input_k); + if (!IsValueNode(input_k)) { + return nullptr; + } + ValuePtr value = GetValueNode(input_k); + MS_EXCEPTION_IF_NULL(value); + auto tensor = value->cast(); + MS_EXCEPTION_IF_NULL(tensor); + int32_t *data = reinterpret_cast(tensor->data_c()); + MS_EXCEPTION_IF_NULL(data); + auto new_value_node = std::make_shared(MakeValue(*data)); + new_cnode->set_input(kTopkIndexK + 1, new_value_node); + + std::unordered_set attr_index{kTopkIndexK}; + ConstInputToAttr(new_cnode, attr_index); + auto indices_const = CreateValueNode(new_cnode); + new_cnode->add_input(indices_const); + MS_EXCEPTION_IF_NULL(supported_checker_); + if (!supported_checker_->CheckSupported(new_cnode, CreateKernelBuildInfo())) { + return nullptr; + } + if (kernel_graph != nullptr) { kernel_graph->AddValueNodeToGraph(indices_const); } - CNodePtr new_cnode = nullptr; - if (kernel_graph == nullptr) { - new_cnode = std::make_shared(*cnode); - } else { - new_cnode = kernel_graph->NewCNode(cnode); - } - MS_EXCEPTION_IF_NULL(new_cnode); return new_cnode; } } // namespace opt diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.h b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.h index 8fcbbac4755..e7293e1fa39 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.h +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/topk_split.h @@ -16,15 +16,22 @@ #ifndef 
MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_ +#include #include "pre_activate/common/optimizer.h" +#include "pre_activate/ascend/ascend_helper.h" + namespace mindspore { namespace opt { class TopKSplit : public PatternProcessPass { public: - explicit TopKSplit(bool multigraph = true) : PatternProcessPass("topk_split", multigraph) {} + explicit TopKSplit(bool multigraph = true) + : PatternProcessPass("topk_split", multigraph), supported_checker_(std::make_shared()) {} ~TopKSplit() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + SupportedCheckerPtr supported_checker_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/helper.cc b/mindspore/ccsrc/pre_activate/common/helper.cc index de452392683..9e8187ffb23 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.cc +++ b/mindspore/ccsrc/pre_activate/common/helper.cc @@ -422,5 +422,47 @@ AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePt AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get()); return tuple_getitem; } + +void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &input_attrs) { + MS_EXCEPTION_IF_NULL(cnode); + std::vector new_inputs; + std::vector new_input_names; + auto primitive = AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + auto input_names = primitive->GetAttr(kAttrInputNames); + if (input_names == nullptr) { + MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]"; + return; + } + auto input_names_vec = GetValue>(input_names); + auto inputs = cnode->inputs(); + new_inputs.push_back(inputs[0]); + bool need_update = false; + for (size_t i = 0; i < inputs.size() - 1; ++i) { + auto input_node = inputs[i + 1]; + 
MS_EXCEPTION_IF_NULL(input_node); + if (input_attrs.find(i) != input_attrs.end() && input_node->isa()) { + auto value_node = input_node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]"; + if (i >= input_names_vec.size()) { + MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]"; + } + primitive->set_attr(input_names_vec[i], value_node->value()); + need_update = true; + } else { + new_inputs.push_back(input_node); + if (i < input_names_vec.size()) { + new_input_names.push_back(input_names_vec[i]); + } + } + } + if (need_update) { + // Update cnode's inputs + cnode->set_inputs(new_inputs); + // Update cnode's input_names attr + primitive->set_attr(kAttrInputNames, MakeValue(new_input_names)); + } +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/helper.h b/mindspore/ccsrc/pre_activate/common/helper.h index 04a4dd6c81f..9ef57d8e7cb 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.h +++ b/mindspore/ccsrc/pre_activate/common/helper.h @@ -19,6 +19,7 @@ #include #include #include +#include #include "ir/func_graph.h" #include "session/kernel_graph.h" #include "common/utils.h" @@ -86,6 +87,7 @@ constexpr size_t kAdamApplyOneOutputNum = 3; constexpr size_t kBackendTransDataInputNum = 2; constexpr size_t kApplyMomentumInputNum = 6; constexpr size_t kBiasAddInputNum = 3; +constexpr size_t kTopkInputNum = 3; enum FusedBatchNormInput { kX = 1, @@ -150,6 +152,8 @@ void RemoveNopNode(session::KernelGraph *const graph); AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); + +void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &input_attrs); } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_ diff 
--git a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc index fb47c9fc2ac..0b4263685b5 100644 --- a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc +++ b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc @@ -52,7 +52,6 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(kFlattenGradOpName, {1}); Register(kExpandDimsOpName, {1}); Register(kSplitOpName, {0}); - Register(kTopKOpName, {1}); Register(kErfOpName, {1}); Register(kSparseApplyAdagradOpName, {2}); Register(kResizeNearestNeighborGrad, {1}); diff --git a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc b/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc index 15d62a164fb..1f9e2712a6a 100644 --- a/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc +++ b/mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc @@ -18,10 +18,10 @@ #include #include #include -#include #include #include "pre_activate/pass/const_input_to_attr_registry.h" +#include "pre_activate/common/helper.h" #include "utils/utils.h" #include "utils/context/ms_context.h" #include "operator/ops.h" @@ -29,50 +29,6 @@ namespace mindspore { namespace opt { -namespace { -void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &input_attrs) { - MS_EXCEPTION_IF_NULL(cnode); - std::vector new_inputs; - std::vector new_input_names; - auto primitive = AnfAlgo::GetCNodePrimitive(cnode); - MS_EXCEPTION_IF_NULL(primitive); - auto input_names = primitive->GetAttr(kAttrInputNames); - if (input_names == nullptr) { - MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]"; - return; - } - auto input_names_vec = GetValue>(input_names); - auto inputs = cnode->inputs(); - new_inputs.push_back(inputs[0]); - bool need_update = false; - for (size_t i = 0; i < inputs.size() - 1; ++i) { - auto input_node = inputs[i + 1]; - 
MS_EXCEPTION_IF_NULL(input_node); - if (input_attrs.find(i) != input_attrs.end() && input_node->isa()) { - auto value_node = input_node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]"; - if (i >= input_names_vec.size()) { - MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]"; - } - primitive->set_attr(input_names_vec[i], value_node->value()); - need_update = true; - } else { - new_inputs.push_back(input_node); - if (i < input_names_vec.size()) { - new_input_names.push_back(input_names_vec[i]); - } - } - } - if (need_update) { - // Update cnode's inputs - cnode->set_inputs(new_inputs); - // Update cnode's input_names attr - primitive->set_attr(kAttrInputNames, MakeValue(new_input_names)); - } -} -} // namespace - const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) { diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc index 94fa04ef7af..43ddc046b7c 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc @@ -17,8 +17,13 @@ #include "common/backend_common_test.h" #include "common/py_func_graph_fetcher.h" #include "device/kernel_info.h" -#include "pre_activate/ascend/ir_fission/topk_split.h" +#include "pre_activate/pass/convert_const_input_to_attr.h" #include "debug/anf_ir_dump.h" +#define private public +#define protected public +#include "pre_activate/ascend/ir_fission/topk_split.h" +#undef private +#undef protected namespace mindspore { namespace opt { @@ -30,6 +35,15 @@ class TestHWTopKSplit : public BackendCommon { UT::PyFuncGraphFetcher get_py_fun_; }; +class MockSupportedChecker : public SupportedChecker { + public: + 
MockSupportedChecker() = default; + ~MockSupportedChecker() override = default; + bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override { + return true; + } +}; // namespace opt + TEST_F(TestHWTopKSplit, test_topk_split) { /* * def before(input): @@ -40,19 +54,25 @@ TEST_F(TestHWTopKSplit, test_topk_split) { FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_topk_split", "before"); std::vector shp{4, 4}; auto x_abstract = std::make_shared(kFloat32, shp); - g->parameters()[0]->set_abstract(x_abstract); - auto ret = g->get_return(); - EXPECT_NE(ret, nullptr); - auto tuple_getitem = ret->input(1); - EXPECT_NE(tuple_getitem, nullptr); - auto topk = tuple_getitem->cast()->input(1); - topk->set_abstract(x_abstract); + AbstractBasePtrList args_spec_list{x_abstract}; + auto kernel_graph = GetKernelGraph(g, args_spec_list); auto optimizer = std::make_shared(); auto pm = std::make_shared(); - pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); + auto topk_split = std::make_shared(); + topk_split->supported_checker_ = std::make_shared(); + pm->AddPass(topk_split); optimizer->AddPassManager(pm); - FuncGraphPtr new_graph = optimizer->Optimize(g); + FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph); + + auto ret = new_graph->get_return(); + EXPECT_NE(ret, nullptr); + auto make_tuple = ret->input(1); + EXPECT_NE(make_tuple, nullptr); + auto tuple_getitem = make_tuple->cast()->input(1); + EXPECT_NE(tuple_getitem, nullptr); + auto topk = tuple_getitem->cast()->input(1); auto topk_cnode = topk->cast(); EXPECT_EQ(topk_cnode->inputs().size(), 3); EXPECT_TRUE(topk_cnode->input(2)->isa()); diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/topk_split_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/topk_split_test.py index 4cdbfa084e7..c1734198978 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/topk_split_test.py +++ 
b/tests/ut/cpp/python_input/gtest_input/pre_activate/topk_split_test.py @@ -35,7 +35,7 @@ def test_topk_split(tag): @fns def before(input): - topk = TopK(input) + topk = TopK(input, 2) output = tuple_getitem(topk, 0) return output