!1554 Check the size of topk input names before converting input to attr

Merge pull request !1554 from YuJianfeng/master
This commit is contained in:
mindspore-ci-bot 2020-05-28 11:44:33 +08:00 committed by Gitee
commit ab94e92cd1
2 changed files with 59 additions and 9 deletions

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "pre_activate/ascend/ir_fission/topk_split.h"
#include <string>
#include <vector>
#include <memory>
#include <unordered_set>
@ -102,6 +103,11 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
// set value node as topk's input
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode, kAttrInputNames);
if (input_names_vec.size() < kTopkIndexK + 1) {
MS_LOG(INFO) << "The input k of topk has been converted to attr";
return nullptr;
}
// Copy a new node to check supported.
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kTopKOpName))};
new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());

View File

@ -19,6 +19,7 @@
#include "device/kernel_info.h"
#include "pre_activate/pass/convert_const_input_to_attr.h"
#include "debug/anf_ir_dump.h"
#include "session/anf_runtime_algorithm.h"
#define private public
#define protected public
#include "pre_activate/ascend/ir_fission/topk_split.h"
@ -32,6 +33,21 @@ class TestHWTopKSplit : public BackendCommon {
TestHWTopKSplit() : get_py_fun_("gtest_input.pre_activate.topk_split_test", true) {}
~TestHWTopKSplit() override = default;
CNodePtr GetTopkCNodeFromKernelGraph(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto ret = func_graph->get_return();
MS_EXCEPTION_IF_NULL(ret);
auto make_tuple = ret->input(1);
MS_EXCEPTION_IF_NULL(make_tuple);
auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1);
MS_EXCEPTION_IF_NULL(tuple_getitem);
auto topk = tuple_getitem->cast<CNodePtr>()->input(1);
MS_EXCEPTION_IF_NULL(topk);
auto topk_cnode = topk->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(topk_cnode);
return topk_cnode;
}
UT::PyFuncGraphFetcher get_py_fun_;
};
@ -39,7 +55,8 @@ class MockSupportedChecker : public SupportedChecker {
public:
MockSupportedChecker() = default;
~MockSupportedChecker() override = default;
bool CheckAiCoreSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
bool CheckAiCoreSupported(const AnfNodePtr &anf_node,
const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
return true;
}
}; // namespace opt
@ -66,14 +83,7 @@ TEST_F(TestHWTopKSplit, test_topk_split) {
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph);
auto ret = new_graph->get_return();
EXPECT_NE(ret, nullptr);
auto make_tuple = ret->input(1);
EXPECT_NE(make_tuple, nullptr);
auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1);
EXPECT_NE(tuple_getitem, nullptr);
auto topk = tuple_getitem->cast<CNodePtr>()->input(1);
auto topk_cnode = topk->cast<CNodePtr>();
auto topk_cnode = GetTopkCNodeFromKernelGraph(new_graph);
EXPECT_EQ(topk_cnode->inputs().size(), 3);
EXPECT_TRUE(topk_cnode->input(2)->isa<ValueNode>());
auto value_node = topk_cnode->input(2)->cast<ValueNodePtr>();
@ -82,5 +92,39 @@ TEST_F(TestHWTopKSplit, test_topk_split) {
EXPECT_EQ(tensor->shape().size(), 1);
EXPECT_EQ(tensor->shape()[0], 4);
}
TEST_F(TestHWTopKSplit, test_topk_no_split) {
/*
* def before(input):
* topk = TopKSplit(input)
* output = tuple_getitem(topk, 0)
* return output
*/
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_topk_split", "before");
std::vector<int> shp{4, 4};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list{x_abstract};
auto kernel_graph = GetKernelGraph(g, args_spec_list);
CNodePtr topk_cnode = GetTopkCNodeFromKernelGraph(kernel_graph);
EXPECT_EQ(topk_cnode->inputs().size(), 3);
auto input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(topk_cnode, kAttrInputNames);
EXPECT_EQ(input_names_vec.size(), 2);
std::unordered_set<size_t> attr_index{1};
ConstInputToAttr(topk_cnode, attr_index);
EXPECT_EQ(topk_cnode->inputs().size(), 2);
input_names_vec = AnfAlgo::GetNodeAttr<std::vector<std::string>>(topk_cnode, kAttrInputNames);
EXPECT_EQ(input_names_vec.size(), 1);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::ConvertConstInputToAttr>());
auto topk_split = std::make_shared<opt::TopKSplit>();
topk_split->supported_checker_ = std::make_shared<MockSupportedChecker>();
pm->AddPass(topk_split);
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph);
EXPECT_EQ(topk_cnode, GetTopkCNodeFromKernelGraph(new_graph));
}
} // namespace opt
} // namespace mindspore