From 8fb7d11ecbfb6f5776cb9f9e34d90f1972683e74 Mon Sep 17 00:00:00 2001 From: jjfeing Date: Thu, 26 Nov 2020 18:39:57 +0800 Subject: [PATCH] fix topk help 4096 --- .../optimizer/ascend/ir_fission/topk_split.cc | 19 ++++++++----------- .../ascend/ir_fission/topk_split_test.cc | 2 +- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc index 857dfa1c8fc..f5b968e1a03 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc @@ -26,15 +26,13 @@ #include "runtime/device/kernel_info.h" #include "utils/ms_context.h" -namespace mindspore { -namespace opt { +namespace mindspore::opt { constexpr size_t kFloat16Len = 2; // size of float16; constexpr size_t kTopkIndexK = 1; namespace { -tensor::TensorPtr CreateTensor(const AnfNodePtr &node) { +tensor::TensorPtr CreateTensor() { // 1 create tensor - auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0); - auto last_dim = shape[shape.size() - 1]; + const size_t last_dim = 4096; std::vector indices_shape = {SizeToLong(last_dim * 2)}; TensorTypePtr tensor_type = std::make_shared(kFloat16); MS_EXCEPTION_IF_NULL(tensor_type); @@ -63,8 +61,8 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) { return indices_tensor; } -ValueNodePtr CreateValueNode(const AnfNodePtr &node) { - tensor::TensorPtr indices_tensor = CreateTensor(node); +ValueNodePtr CreateValueNode() { + tensor::TensorPtr indices_tensor = CreateTensor(); MS_EXCEPTION_IF_NULL(indices_tensor); auto indices_const = std::make_shared(indices_tensor); MS_EXCEPTION_IF_NULL(indices_const); @@ -159,14 +157,14 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod MS_EXCEPTION_IF_NULL(value); auto tensor = value->cast(); MS_EXCEPTION_IF_NULL(tensor); - int32_t *data = reinterpret_cast(tensor->data_c()); + auto *data = reinterpret_cast(tensor->data_c()); MS_EXCEPTION_IF_NULL(data); auto new_value_node = std::make_shared(MakeValue(*data)); new_cnode->set_input(kTopkIndexK + 1, new_value_node); std::unordered_set attr_index{kTopkIndexK}; ConstInputToAttr(new_cnode, attr_index); - auto indices_const = CreateValueNode(new_cnode); + auto indices_const = CreateValueNode(); new_cnode->add_input(indices_const); MS_EXCEPTION_IF_NULL(supported_checker_); if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) { @@ -181,5 +179,4 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod return new_cnode; } -} // namespace opt -} // namespace mindspore +} // namespace mindspore::opt diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc index 308bffc050d..95b8db50f31 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/topk_split_test.cc @@ -89,7 +89,7 @@ TEST_F(TestHWTopKSplit, test_topk_split) { EXPECT_TRUE(value_node->value()->isa()); auto tensor = value_node->value()->cast(); EXPECT_EQ(tensor->shape().size(), 1); - EXPECT_EQ(tensor->shape()[0], 8); + EXPECT_EQ(tensor->shape()[0], 4096*2); } TEST_F(TestHWTopKSplit, test_topk_no_split) {