forked from mindspore-Ecosystem/mindspore
fix topk help 4096
This commit is contained in:
parent
57899791d3
commit
8fb7d11ecb
|
@ -26,15 +26,13 @@
|
||||||
#include "runtime/device/kernel_info.h"
|
#include "runtime/device/kernel_info.h"
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore::opt {
|
||||||
namespace opt {
|
|
||||||
constexpr size_t kFloat16Len = 2; // size of float16;
|
constexpr size_t kFloat16Len = 2; // size of float16;
|
||||||
constexpr size_t kTopkIndexK = 1;
|
constexpr size_t kTopkIndexK = 1;
|
||||||
namespace {
|
namespace {
|
||||||
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
|
tensor::TensorPtr CreateTensor() {
|
||||||
// 1 create tensor
|
// 1 create tensor
|
||||||
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
|
const size_t last_dim = 4096;
|
||||||
auto last_dim = shape[shape.size() - 1];
|
|
||||||
std::vector<int64_t> indices_shape = {SizeToLong(last_dim * 2)};
|
std::vector<int64_t> indices_shape = {SizeToLong(last_dim * 2)};
|
||||||
TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
|
TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
|
||||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||||
|
@ -63,8 +61,8 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
|
||||||
return indices_tensor;
|
return indices_tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
|
ValueNodePtr CreateValueNode() {
|
||||||
tensor::TensorPtr indices_tensor = CreateTensor(node);
|
tensor::TensorPtr indices_tensor = CreateTensor();
|
||||||
MS_EXCEPTION_IF_NULL(indices_tensor);
|
MS_EXCEPTION_IF_NULL(indices_tensor);
|
||||||
auto indices_const = std::make_shared<ValueNode>(indices_tensor);
|
auto indices_const = std::make_shared<ValueNode>(indices_tensor);
|
||||||
MS_EXCEPTION_IF_NULL(indices_const);
|
MS_EXCEPTION_IF_NULL(indices_const);
|
||||||
|
@ -159,14 +157,14 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
|
||||||
MS_EXCEPTION_IF_NULL(value);
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
auto tensor = value->cast<tensor::TensorPtr>();
|
auto tensor = value->cast<tensor::TensorPtr>();
|
||||||
MS_EXCEPTION_IF_NULL(tensor);
|
MS_EXCEPTION_IF_NULL(tensor);
|
||||||
int32_t *data = reinterpret_cast<int32_t *>(tensor->data_c());
|
auto *data = reinterpret_cast<int32_t *>(tensor->data_c());
|
||||||
MS_EXCEPTION_IF_NULL(data);
|
MS_EXCEPTION_IF_NULL(data);
|
||||||
auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data));
|
auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data));
|
||||||
new_cnode->set_input(kTopkIndexK + 1, new_value_node);
|
new_cnode->set_input(kTopkIndexK + 1, new_value_node);
|
||||||
|
|
||||||
std::unordered_set<size_t> attr_index{kTopkIndexK};
|
std::unordered_set<size_t> attr_index{kTopkIndexK};
|
||||||
ConstInputToAttr(new_cnode, attr_index);
|
ConstInputToAttr(new_cnode, attr_index);
|
||||||
auto indices_const = CreateValueNode(new_cnode);
|
auto indices_const = CreateValueNode();
|
||||||
new_cnode->add_input(indices_const);
|
new_cnode->add_input(indices_const);
|
||||||
MS_EXCEPTION_IF_NULL(supported_checker_);
|
MS_EXCEPTION_IF_NULL(supported_checker_);
|
||||||
if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) {
|
if (!supported_checker_->CheckAICoreSupported(new_cnode, CreateKernelBuildInfo())) {
|
||||||
|
@ -181,5 +179,4 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
|
||||||
|
|
||||||
return new_cnode;
|
return new_cnode;
|
||||||
}
|
}
|
||||||
} // namespace opt
|
} // namespace mindspore::opt
|
||||||
} // namespace mindspore
|
|
||||||
|
|
|
@ -89,7 +89,7 @@ TEST_F(TestHWTopKSplit, test_topk_split) {
|
||||||
EXPECT_TRUE(value_node->value()->isa<tensor::Tensor>());
|
EXPECT_TRUE(value_node->value()->isa<tensor::Tensor>());
|
||||||
auto tensor = value_node->value()->cast<tensor::TensorPtr>();
|
auto tensor = value_node->value()->cast<tensor::TensorPtr>();
|
||||||
EXPECT_EQ(tensor->shape().size(), 1);
|
EXPECT_EQ(tensor->shape().size(), 1);
|
||||||
EXPECT_EQ(tensor->shape()[0], 8);
|
EXPECT_EQ(tensor->shape()[0], 4096*2);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TestHWTopKSplit, test_topk_no_split) {
|
TEST_F(TestHWTopKSplit, test_topk_no_split) {
|
||||||
|
|
Loading…
Reference in New Issue