!12316 [GraphKernel] Exclude special node in expander and basic fusion.

From: @tronzhang
Reviewed-by: @gaoxiong1,@anyrenwei
Signed-off-by: @anyrenwei
This commit is contained in:
mindspore-ci-bot 2021-02-10 14:32:25 +08:00 committed by Gitee
commit 5394ff8ac3
4 changed files with 29 additions and 2 deletions

View File

@ -140,7 +140,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto node = (*iter)->cast<CNodePtr>();
if (node == nullptr || fused_ops->count(node)) {
if (node == nullptr || IsKeepBasicNode(node) || fused_ops->count(node)) {
continue;
}
bool is_fusible_op = IsFusibleOp(node);

View File

@ -147,7 +147,8 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(mng);
for (const auto &n : todos) {
auto node = n->cast<CNodePtr>();
if (node == nullptr || !AnfAlgo::IsRealKernel(node) || AnfAlgo::IsGraphKernel(node) || !CanExpand(node)) {
if (node == nullptr || IsKeepBasicNode(node) || !AnfAlgo::IsRealKernel(node) || AnfAlgo::IsGraphKernel(node) ||
!CanExpand(node)) {
continue;
}

View File

@ -773,5 +773,30 @@ void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfN
// Set attr secondly.
AnfAlgo::SetNodeAttr(key, value, node);
}
bool IsKeepBasicNode(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>()) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
static std::vector<std::string> contagious_attrs = {"inplace_group", "inplace_algo", "inplace_output_index",
"aggregate", "aggregate_input_indexx"};
static std::vector<std::function<bool(const AnfNodePtr &node)>> attrs_with_value = {
[](const AnfNodePtr &n) -> bool { return AnfAlgo::GetBooleanAttr(n, "skip"); }};
// If node contain attribute in contagious_attrs, it have to keep basic no matter what the value is.
// If node contain attribute in attrs_with_value, it only have to keep basic when the check result is true.
if (std::any_of(contagious_attrs.cbegin(), contagious_attrs.cend(),
[&cnode](const std::string &attr_name) -> bool { return AnfAlgo::HasNodeAttr(attr_name, cnode); }) ||
std::any_of(attrs_with_value.cbegin(), attrs_with_value.cend(),
[&cnode](std::function<bool(const AnfNodePtr &node)> func) -> bool { return func(cnode); })) {
return true;
}
return false;
}
} // namespace opt
} // namespace mindspore

View File

@ -89,6 +89,7 @@ std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info);
void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
bool IsKeepBasicNode(const AnfNodePtr &node);
template <typename T>
ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) {