forked from mindspore-Ecosystem/mindspore
!12316 [GraphKernel] Exclude special node in expander and basic fusion.
From: @tronzhang Reviewed-by: @gaoxiong1,@anyrenwei Signed-off-by: @anyrenwei
This commit is contained in:
commit
5394ff8ac3
|
@ -140,7 +140,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
|
|||
|
||||
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||
auto node = (*iter)->cast<CNodePtr>();
|
||||
if (node == nullptr || fused_ops->count(node)) {
|
||||
if (node == nullptr || IsKeepBasicNode(node) || fused_ops->count(node)) {
|
||||
continue;
|
||||
}
|
||||
bool is_fusible_op = IsFusibleOp(node);
|
||||
|
|
|
@ -147,7 +147,8 @@ bool GraphKernelExpander::DoExpand(const FuncGraphPtr &func_graph) {
|
|||
MS_EXCEPTION_IF_NULL(mng);
|
||||
for (const auto &n : todos) {
|
||||
auto node = n->cast<CNodePtr>();
|
||||
if (node == nullptr || !AnfAlgo::IsRealKernel(node) || AnfAlgo::IsGraphKernel(node) || !CanExpand(node)) {
|
||||
if (node == nullptr || IsKeepBasicNode(node) || !AnfAlgo::IsRealKernel(node) || AnfAlgo::IsGraphKernel(node) ||
|
||||
!CanExpand(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
@ -773,5 +773,30 @@ void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfN
|
|||
// Set attr secondly.
|
||||
AnfAlgo::SetNodeAttr(key, value, node);
|
||||
}
|
||||
|
||||
bool IsKeepBasicNode(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>()) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
static std::vector<std::string> contagious_attrs = {"inplace_group", "inplace_algo", "inplace_output_index",
|
||||
"aggregate", "aggregate_input_indexx"};
|
||||
static std::vector<std::function<bool(const AnfNodePtr &node)>> attrs_with_value = {
|
||||
[](const AnfNodePtr &n) -> bool { return AnfAlgo::GetBooleanAttr(n, "skip"); }};
|
||||
|
||||
// If node contain attribute in contagious_attrs, it have to keep basic no matter what the value is.
|
||||
// If node contain attribute in attrs_with_value, it only have to keep basic when the check result is true.
|
||||
if (std::any_of(contagious_attrs.cbegin(), contagious_attrs.cend(),
|
||||
[&cnode](const std::string &attr_name) -> bool { return AnfAlgo::HasNodeAttr(attr_name, cnode); }) ||
|
||||
std::any_of(attrs_with_value.cbegin(), attrs_with_value.cend(),
|
||||
[&cnode](std::function<bool(const AnfNodePtr &node)> func) -> bool { return func(cnode); })) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -89,6 +89,7 @@ std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
|
|||
|
||||
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info);
|
||||
void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
|
||||
bool IsKeepBasicNode(const AnfNodePtr &node);
|
||||
|
||||
template <typename T>
|
||||
ValueNodePtr CreateScalarTensorValueNode(const DataInfo &info, T value, size_t data_length) {
|
||||
|
|
Loading…
Reference in New Issue