forked from mindspore-Ecosystem/mindspore
!1554 Check the size of topk input names before converting input to attr
Merge pull request !1554 from YuJianfeng/master
This commit is contained in:
commit
ab94e92cd1
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "pre_activate/ascend/ir_fission/topk_split.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
|
@ -102,6 +103,11 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
|
|||
// set value node as topk's input
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
|
||||
if (input_names_vec.size() < kTopkIndexK + 1) {
|
||||
MS_LOG(INFO) << "The input k of topk has been converted to attr";
|
||||
return nullptr;
|
||||
}
|
||||
// Copy a new node to check supported.
|
||||
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kTopKOpName))};
|
||||
new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include "device/kernel_info.h"
|
||||
#include "pre_activate/pass/convert_const_input_to_attr.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
#define private public
|
||||
#define protected public
|
||||
#include "pre_activate/ascend/ir_fission/topk_split.h"
|
||||
|
@ -32,6 +33,21 @@ class TestHWTopKSplit : public BackendCommon {
|
|||
TestHWTopKSplit() : get_py_fun_("gtest_input.pre_activate.topk_split_test", true) {}
|
||||
~TestHWTopKSplit() override = default;
|
||||
|
||||
CNodePtr GetTopkCNodeFromKernelGraph(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto ret = func_graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
auto make_tuple = ret->input(1);
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1);
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
auto topk = tuple_getitem->cast<CNodePtr>()->input(1);
|
||||
MS_EXCEPTION_IF_NULL(topk);
|
||||
auto topk_cnode = topk->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(topk_cnode);
|
||||
return topk_cnode;
|
||||
}
|
||||
|
||||
UT::PyFuncGraphFetcher get_py_fun_;
|
||||
};
|
||||
|
||||
|
@ -39,7 +55,8 @@ class MockSupportedChecker : public SupportedChecker {
|
|||
public:
|
||||
MockSupportedChecker() = default;
|
||||
~MockSupportedChecker() override = default;
|
||||
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
||||
bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
|
||||
const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
||||
return true;
|
||||
}
|
||||
}; // namespace opt
|
||||
|
@ -66,14 +83,7 @@ TEST_F(TestHWTopKSplit, test_topk_split) {
|
|||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph);
|
||||
|
||||
auto ret = new_graph->get_return();
|
||||
EXPECT_NE(ret, nullptr);
|
||||
auto make_tuple = ret->input(1);
|
||||
EXPECT_NE(make_tuple, nullptr);
|
||||
auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1);
|
||||
EXPECT_NE(tuple_getitem, nullptr);
|
||||
auto topk = tuple_getitem->cast<CNodePtr>()->input(1);
|
||||
auto topk_cnode = topk->cast<CNodePtr>();
|
||||
auto topk_cnode = GetTopkCNodeFromKernelGraph(new_graph);
|
||||
EXPECT_EQ(topk_cnode->inputs().size(), 3);
|
||||
EXPECT_TRUE(topk_cnode->input(2)->isa<ValueNode>());
|
||||
auto value_node = topk_cnode->input(2)->cast<ValueNodePtr>();
|
||||
|
@ -82,5 +92,39 @@ TEST_F(TestHWTopKSplit, test_topk_split) {
|
|||
EXPECT_EQ(tensor->shape().size(), 1);
|
||||
EXPECT_EQ(tensor->shape()[0], 4);
|
||||
}
|
||||
|
||||
TEST_F(TestHWTopKSplit, test_topk_no_split) {
|
||||
/*
|
||||
* def before(input):
|
||||
* topk = TopKSplit(input)
|
||||
* output = tuple_getitem(topk, 0)
|
||||
* return output
|
||||
*/
|
||||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_topk_split", "before");
|
||||
std::vector<int> shp{4, 4};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
AbstractBasePtrList args_spec_list{x_abstract};
|
||||
auto kernel_graph = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
CNodePtr topk_cnode = GetTopkCNodeFromKernelGraph(kernel_graph);
|
||||
EXPECT_EQ(topk_cnode->inputs().size(), 3);
|
||||
auto input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(topk_cnode, kAttrInputNames);
|
||||
EXPECT_EQ(input_names_vec.size(), 2);
|
||||
std::unordered_set<size_t> attr_index{1};
|
||||
ConstInputToAttr(topk_cnode, attr_index);
|
||||
EXPECT_EQ(topk_cnode->inputs().size(), 2);
|
||||
input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(topk_cnode, kAttrInputNames);
|
||||
EXPECT_EQ(input_names_vec.size(), 1);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::ConvertConstInputToAttr>());
|
||||
auto topk_split = std::make_shared<opt::TopKSplit>();
|
||||
topk_split->supported_checker_ = std::make_shared<MockSupportedChecker>();
|
||||
pm->AddPass(topk_split);
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph);
|
||||
EXPECT_EQ(topk_cnode, GetTopkCNodeFromKernelGraph(new_graph));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue