forked from mindspore-Ecosystem/mindspore
Check topk supported before converting input to attr
This commit is contained in:
parent
c6d21ccd12
commit
ce2a13fcda
|
@ -21,6 +21,7 @@
|
|||
#include <vector>
|
||||
#include "device/ascend/kernel_select_ascend.h"
|
||||
#include "kernel/kernel_query.h"
|
||||
#include "kernel/tbe/tbe_kernel_select.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -36,6 +37,16 @@ class KernelSelect {
|
|||
};
|
||||
using KernelSelectPtr = std::shared_ptr<KernelSelect>;
|
||||
|
||||
// Overridable hook for asking whether a node is supported with a given kernel build info.
// The default implementation forwards to kernel::CheckSupported; CheckSupported is virtual
// so that a subclass can substitute its own answer.
class SupportedChecker {
|
||||
public:
|
||||
SupportedChecker() = default;
|
||||
virtual ~SupportedChecker() = default;
|
||||
// Returns whatever kernel::CheckSupported returns for the same two arguments.
virtual bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) {
|
||||
return kernel::CheckSupported(anf_node, select_kernel_build_info);
|
||||
}
|
||||
};
|
||||
using SupportedCheckerPtr = std::shared_ptr<SupportedChecker>;
|
||||
|
||||
class KernelQuery {
|
||||
public:
|
||||
KernelQuery() = default;
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
#include "pre_activate/ascend/ir_fission/topk_split.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include "pre_activate/common/helper.h"
|
||||
#include "kernel/kernel_build_info.h"
|
||||
#include "utils/utils.h"
|
||||
#include "session/kernel_graph.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
@ -25,6 +28,7 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
constexpr size_t kFloat16Len = 2; // size of float16;
|
||||
constexpr size_t kTopkIndexK = 1;
|
||||
namespace {
|
||||
tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
|
||||
// 1 create tensor
|
||||
|
@ -70,37 +74,68 @@ ValueNodePtr CreateValueNode(const AnfNodePtr &node) {
|
|||
AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), indices_const.get());
|
||||
return indices_const;
|
||||
}
|
||||
|
||||
// Builds a fixed kernel build info with two inputs and two outputs, all in kOpFormat_DEFAULT.
// Inputs are float16/float16; outputs are float16/int32.
// TopKSplit::Process passes the result to the supported checker.
kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
|
||||
builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
|
||||
builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16});
|
||||
builder.SetOutputsDeviceType({kNumberTypeFloat16, kNumberTypeInt32});
|
||||
return builder.Build();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Defines the graph pattern this pass matches: a kTopKOpName primitive followed by pattern variables.
// NOTE(review): as listed, this body has two return statements, so the two-variable pattern
// {prim, X1, X2} is unreachable and only {prim, X} is returned. This looks like removed and added
// diff lines shown together; Process checks for kTopkInputNum inputs, so presumably the
// two-variable form is the intended one -- confirm against the post-change file.
const BaseRef TopKSplit::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
MS_EXCEPTION_IF_NULL(X);
|
||||
VarPtr X1 = std::make_shared<Var>();
|
||||
VarPtr X2 = std::make_shared<Var>();
|
||||
auto prim = std::make_shared<Primitive>(kTopKOpName);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
return VectorRef({prim, X});
|
||||
return VectorRef({prim, X1, X2});
|
||||
}
|
||||
|
||||
// Handles a node matched by DefinePattern. The visible steps are: copy the TopK node into a fresh
// CNode, turn its constant k tensor into a scalar value node and then into an attribute, append a
// constant value node created by CreateValueNode, and ask supported_checker_ whether the result is
// supported. Returns nullptr when k is not a constant tensor or when the checker rejects the node.
// NOTE(review): this listing declares both `indices_const` and `new_cnode` twice in the same scope,
// so it does not compile as shown. It appears to interleave removed and added diff lines; confirm
// which statements belong to the current version before relying on the flow described above.
const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// kernel_graph is null when func_graph is not a KernelGraph; both cases are handled below.
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
auto indices_const = CreateValueNode(node);
|
||||
// set value node as topk's input
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_LOG(INFO) << "already has input size: " << cnode->inputs().size();
|
||||
cnode->add_input(indices_const);
|
||||
// Copy a new node to check supported.
|
||||
// The copy gets a fresh primitive value node plus every original input after the first.
std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kTopKOpName))};
|
||||
new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||
CNodePtr new_cnode = func_graph->NewCNode(new_inputs);
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
new_cnode->set_abstract(cnode->abstract());
|
||||
new_cnode->set_scope(cnode->scope());
|
||||
AnfAlgo::CopyNodeAttrs(cnode, new_cnode);
|
||||
CheckCNodeInputSize(new_cnode, kTopkInputNum);
|
||||
// Convert the tensor input to scalar and convert it to attr
|
||||
// kTopkIndexK counts real inputs only, hence the +1 to skip the primitive in slot 0.
auto input_k = new_cnode->input(kTopkIndexK + 1);
|
||||
MS_EXCEPTION_IF_NULL(input_k);
|
||||
// No rewrite when k is not a constant tensor value node.
if (!IsValueNode<tensor::Tensor>(input_k)) {
|
||||
return nullptr;
|
||||
}
|
||||
ValuePtr value = GetValueNode(input_k);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
auto tensor = value->cast<tensor::TensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
// NOTE(review): the first element is read as int32 without checking the tensor's dtype --
// confirm that k is always stored as an int32 tensor.
int32_t *data = reinterpret_cast<int32_t *>(tensor->data_c());
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
auto new_value_node = std::make_shared<ValueNode>(MakeValue(*data));
|
||||
new_cnode->set_input(kTopkIndexK + 1, new_value_node);
|
||||
|
||||
// Move the (now scalar) k input into an attribute of the primitive.
std::unordered_set<size_t> attr_index{kTopkIndexK};
|
||||
ConstInputToAttr(new_cnode, attr_index);
|
||||
auto indices_const = CreateValueNode(new_cnode);
|
||||
new_cnode->add_input(indices_const);
|
||||
MS_EXCEPTION_IF_NULL(supported_checker_);
|
||||
// Give up on the rewrite when the checker rejects the node for the fixed build info.
if (!supported_checker_->CheckSupported(new_cnode, CreateKernelBuildInfo())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (kernel_graph != nullptr) {
|
||||
kernel_graph->AddValueNodeToGraph(indices_const);
|
||||
}
|
||||
|
||||
CNodePtr new_cnode = nullptr;
|
||||
if (kernel_graph == nullptr) {
|
||||
new_cnode = std::make_shared<CNode>(*cnode);
|
||||
} else {
|
||||
new_cnode = kernel_graph->NewCNode(cnode);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(new_cnode);
|
||||
return new_cnode;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -16,15 +16,22 @@
|
|||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TOPK_SPLIT_H_
|
||||
|
||||
#include <memory>
|
||||
#include "pre_activate/common/optimizer.h"
|
||||
#include "pre_activate/ascend/ascend_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Pattern-based pass registered under the name "topk_split". It owns a SupportedChecker that
// Process consults before returning a rewritten node.
// NOTE(review): two constructors with the identical signature TopKSplit(bool) are listed below,
// which cannot coexist in one class. This looks like the removed and added diff lines shown
// together; presumably the one that initialises supported_checker_ is current -- confirm.
class TopKSplit : public PatternProcessPass {
|
||||
public:
|
||||
explicit TopKSplit(bool multigraph = true) : PatternProcessPass("topk_split", multigraph) {}
|
||||
// Installs a default SupportedChecker.
explicit TopKSplit(bool multigraph = true)
|
||||
: PatternProcessPass("topk_split", multigraph), supported_checker_(std::make_shared<SupportedChecker>()) {}
|
||||
~TopKSplit() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
// Checker asked by Process whether the rewritten node is supported.
SupportedCheckerPtr supported_checker_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -422,5 +422,47 @@ AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, tuple_getitem.get());
|
||||
return tuple_getitem;
|
||||
}
|
||||
|
||||
// Moves selected constant inputs of `cnode` into attributes of its primitive.
//   cnode       - node to modify in place.
//   input_attrs - indices of the inputs to convert; index i refers to inputs()[i + 1], i.e. the
//                 first entry of inputs() is not counted.
// An indexed input is converted only if it is a ValueNode; its value is stored as an attribute
// named after entry i of the primitive's kAttrInputNames list. If at least one input was converted,
// the node's inputs and the kAttrInputNames attribute are replaced by versions without the
// converted entries. Does nothing when the primitive has no kAttrInputNames attribute. Reports via
// MS_LOG(EXCEPTION) when a converted index has no entry in the name list.
void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<AnfNodePtr> new_inputs;
|
||||
std::vector<std::string> new_input_names;
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto input_names = primitive->GetAttr(kAttrInputNames);
|
||||
// Without a name list there is no attribute name to store a value under.
if (input_names == nullptr) {
|
||||
MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
|
||||
return;
|
||||
}
|
||||
auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
|
||||
auto inputs = cnode->inputs();
|
||||
// The first entry is carried over unchanged; the loop below only looks at inputs[i + 1].
new_inputs.push_back(inputs[0]);
|
||||
bool need_update = false;
|
||||
for (size_t i = 0; i < inputs.size() - 1; ++i) {
|
||||
auto input_node = inputs[i + 1];
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
// Convert only requested indices that actually hold a constant (ValueNode).
if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) {
|
||||
auto value_node = input_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]";
|
||||
if (i >= input_names_vec.size()) {
|
||||
MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
|
||||
}
|
||||
primitive->set_attr(input_names_vec[i], value_node->value());
|
||||
need_update = true;
|
||||
} else {
|
||||
// Kept input: preserve the node and, when one exists, its name.
new_inputs.push_back(input_node);
|
||||
if (i < input_names_vec.size()) {
|
||||
new_input_names.push_back(input_names_vec[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (need_update) {
|
||||
// Update cnode's inputs
|
||||
cnode->set_inputs(new_inputs);
|
||||
// Update cnode's input_names attr
|
||||
primitive->set_attr(kAttrInputNames, MakeValue(new_input_names));
|
||||
}
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include "ir/func_graph.h"
|
||||
#include "session/kernel_graph.h"
|
||||
#include "common/utils.h"
|
||||
|
@ -86,6 +87,7 @@ constexpr size_t kAdamApplyOneOutputNum = 3;
|
|||
constexpr size_t kBackendTransDataInputNum = 2;
|
||||
constexpr size_t kApplyMomentumInputNum = 6;
|
||||
constexpr size_t kBiasAddInputNum = 3;
|
||||
constexpr size_t kTopkInputNum = 3;
|
||||
|
||||
enum FusedBatchNormInput {
|
||||
kX = 1,
|
||||
|
@ -150,6 +152,8 @@ void RemoveNopNode(session::KernelGraph *const graph);
|
|||
AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx);
|
||||
|
||||
bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node);
|
||||
|
||||
void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_HELPER_H_
|
||||
|
|
|
@ -52,7 +52,6 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
|
|||
Register(kFlattenGradOpName, {1});
|
||||
Register(kExpandDimsOpName, {1});
|
||||
Register(kSplitOpName, {0});
|
||||
Register(kTopKOpName, {1});
|
||||
Register(kErfOpName, {1});
|
||||
Register(kSparseApplyAdagradOpName, {2});
|
||||
Register(kResizeNearestNeighborGrad, {1});
|
||||
|
|
|
@ -18,10 +18,10 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <memory>
|
||||
|
||||
#include "pre_activate/pass/const_input_to_attr_registry.h"
|
||||
#include "pre_activate/common/helper.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "operator/ops.h"
|
||||
|
@ -29,50 +29,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
// Moves selected constant inputs of `cnode` into attributes of its primitive. Index i in
// `input_attrs` refers to inputs()[i + 1]; an indexed input is converted only if it is a ValueNode,
// and its value is stored under entry i of the primitive's kAttrInputNames list. When anything was
// converted, the node's inputs and kAttrInputNames are rewritten without the converted entries.
// NOTE(review): an identical ConstInputToAttr also appears earlier in this listing, outside an
// anonymous namespace. Presumably this file-local copy is the one being removed -- confirm.
void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set<size_t> &input_attrs) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::vector<AnfNodePtr> new_inputs;
|
||||
std::vector<std::string> new_input_names;
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto input_names = primitive->GetAttr(kAttrInputNames);
|
||||
// Nothing to do when the primitive carries no input-name list.
if (input_names == nullptr) {
|
||||
MS_LOG(DEBUG) << "input_names are nullptr in cnode[" + cnode->DebugString() + "]";
|
||||
return;
|
||||
}
|
||||
auto input_names_vec = GetValue<std::vector<std::string>>(input_names);
|
||||
auto inputs = cnode->inputs();
|
||||
new_inputs.push_back(inputs[0]);
|
||||
bool need_update = false;
|
||||
for (size_t i = 0; i < inputs.size() - 1; ++i) {
|
||||
auto input_node = inputs[i + 1];
|
||||
MS_EXCEPTION_IF_NULL(input_node);
|
||||
if (input_attrs.find(i) != input_attrs.end() && input_node->isa<ValueNode>()) {
|
||||
auto value_node = input_node->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
MS_LOG(DEBUG) << "start erase input[" << i << "] of cnode[" + cnode->DebugString() + "]";
|
||||
// A converted input must have a name to become an attribute.
if (i >= input_names_vec.size()) {
|
||||
MS_LOG(EXCEPTION) << "index " << i << " is larger than input names size [" << input_names_vec.size() << "]";
|
||||
}
|
||||
primitive->set_attr(input_names_vec[i], value_node->value());
|
||||
need_update = true;
|
||||
} else {
|
||||
new_inputs.push_back(input_node);
|
||||
if (i < input_names_vec.size()) {
|
||||
new_input_names.push_back(input_names_vec[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (need_update) {
|
||||
// Update cnode's inputs
|
||||
cnode->set_inputs(new_inputs);
|
||||
// Update cnode's input_names attr
|
||||
primitive->set_attr(kAttrInputNames, MakeValue(new_input_names));
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
if (node == nullptr || !AnfAlgo::IsRealCNodeKernel(node)) {
|
||||
|
|
|
@ -17,8 +17,13 @@
|
|||
#include "common/backend_common_test.h"
|
||||
#include "common/py_func_graph_fetcher.h"
|
||||
#include "device/kernel_info.h"
|
||||
#include "pre_activate/ascend/ir_fission/topk_split.h"
|
||||
#include "pre_activate/pass/convert_const_input_to_attr.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#define private public
|
||||
#define protected public
|
||||
#include "pre_activate/ascend/ir_fission/topk_split.h"
|
||||
#undef private
|
||||
#undef protected
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -30,6 +35,15 @@ class TestHWTopKSplit : public BackendCommon {
|
|||
UT::PyFuncGraphFetcher get_py_fun_;
|
||||
};
|
||||
|
||||
// Test double for SupportedChecker that reports every node as supported. The test installs it on
// a TopKSplit pass in place of the default checker.
class MockSupportedChecker : public SupportedChecker {
|
||||
public:
|
||||
MockSupportedChecker() = default;
|
||||
~MockSupportedChecker() override = default;
|
||||
// Ignores both arguments and always answers true.
bool CheckSupported(const AnfNodePtr &anf_node, const kernel::KernelBuildInfoPtr &select_kernel_build_info) override {
|
||||
return true;
|
||||
}
|
||||
}; // class MockSupportedChecker
|
||||
|
||||
TEST_F(TestHWTopKSplit, test_topk_split) {
|
||||
/*
|
||||
* def before(input):
|
||||
|
@ -40,19 +54,25 @@ TEST_F(TestHWTopKSplit, test_topk_split) {
|
|||
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_topk_split", "before");
|
||||
std::vector<int> shp{4, 4};
|
||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
g->parameters()[0]->set_abstract(x_abstract);
|
||||
auto ret = g->get_return();
|
||||
EXPECT_NE(ret, nullptr);
|
||||
auto tuple_getitem = ret->input(1);
|
||||
EXPECT_NE(tuple_getitem, nullptr);
|
||||
auto topk = tuple_getitem->cast<CNodePtr>()->input(1);
|
||||
topk->set_abstract(x_abstract);
|
||||
AbstractBasePtrList args_spec_list{x_abstract};
|
||||
auto kernel_graph = GetKernelGraph(g, args_spec_list);
|
||||
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>();
|
||||
pm->AddPass(std::make_shared<opt::TopKSplit>());
|
||||
pm->AddPass(std::make_shared<opt::ConvertConstInputToAttr>());
|
||||
auto topk_split = std::make_shared<opt::TopKSplit>();
|
||||
topk_split->supported_checker_ = std::make_shared<MockSupportedChecker>();
|
||||
pm->AddPass(topk_split);
|
||||
optimizer->AddPassManager(pm);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(g);
|
||||
FuncGraphPtr new_graph = optimizer->Optimize(kernel_graph);
|
||||
|
||||
auto ret = new_graph->get_return();
|
||||
EXPECT_NE(ret, nullptr);
|
||||
auto make_tuple = ret->input(1);
|
||||
EXPECT_NE(make_tuple, nullptr);
|
||||
auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1);
|
||||
EXPECT_NE(tuple_getitem, nullptr);
|
||||
auto topk = tuple_getitem->cast<CNodePtr>()->input(1);
|
||||
auto topk_cnode = topk->cast<CNodePtr>();
|
||||
EXPECT_EQ(topk_cnode->inputs().size(), 3);
|
||||
EXPECT_TRUE(topk_cnode->input(2)->isa<ValueNode>());
|
||||
|
|
|
@ -35,7 +35,7 @@ def test_topk_split(tag):
|
|||
|
||||
@fns
|
||||
def before(input):
|
||||
topk = TopK(input)
|
||||
topk = TopK(input, 2)
|
||||
output = tuple_getitem(topk, 0)
|
||||
return output
|
||||
|
||||
|
|
Loading…
Reference in New Issue