forked from OSSInnovation/mindspore
!14364 [GraphKernel]Remove depend stuff because of its order implication.
From: @tronzhang Reviewed-by: @gaoxiong1,@anyrenwei Signed-off-by: @anyrenwei
This commit is contained in:
commit
d0dd75c4b5
|
@ -15,12 +15,15 @@
|
|||
*/
|
||||
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
|
||||
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "base/core_ops.h"
|
||||
#include "ir/graph_utils.h"
|
||||
|
|
|
@ -1,166 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/depend_formater.h"
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
bool RemoveRedundantDepend(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
|
||||
const auto &users = mng->node_users()[node];
|
||||
std::vector<std::pair<AnfNodePtr, int>> sons;
|
||||
for (const auto &[user, index] : users) {
|
||||
if (!IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
|
||||
sons.emplace_back(user, index);
|
||||
continue;
|
||||
}
|
||||
auto &[fake_first_grad_son, grad_index] = *((mng->node_users()[user]).begin());
|
||||
sons.emplace_back(fake_first_grad_son, grad_index);
|
||||
}
|
||||
|
||||
AnfNodePtrList latter_to_delete;
|
||||
for (const auto &[son, index] : sons) {
|
||||
if (!IsPrimitiveCNode(son, prim::kPrimDepend) || index != kDependAttachNodeIndex) {
|
||||
continue;
|
||||
}
|
||||
|
||||
latter_to_delete.push_back(son);
|
||||
}
|
||||
|
||||
if (latter_to_delete.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr>::iterator delete_begin = latter_to_delete.begin();
|
||||
if (latter_to_delete.size() == sons.size()) {
|
||||
// Left one Depend node relation and delete others!
|
||||
++delete_begin;
|
||||
}
|
||||
for (; delete_begin != latter_to_delete.end(); ++delete_begin) {
|
||||
auto depend_anfnode = *delete_begin;
|
||||
auto depend_cnode = depend_anfnode->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_cnode);
|
||||
auto depend_prior_node = depend_cnode->input(kRealInputIndexInDepend);
|
||||
mng->Replace(depend_anfnode, depend_prior_node);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Picks the node of `main_graph` that new Depend nodes are chained onto.
//
// This is the first data input of the graph's return node or, when that input
// is a MakeTuple, the first data input of the MakeTuple. `mng` is not used.
AnfNodePtr FindPatronNode(const FuncGraphPtr &main_graph, const FuncGraphManagerPtr &mng) {
  auto ret_cnode = main_graph->get_return()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(ret_cnode);

  auto graph_output = ret_cnode->input(kFirstDataInputIndex);
  if (!IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
    // Single output: the output itself is the patron.
    return graph_output;
  }

  // Tuple output: take the first element of the MakeTuple.
  auto tuple_cnode = graph_output->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(tuple_cnode);
  return tuple_cnode->input(kFirstDataInputIndex);
}
|
||||
|
||||
void AddDepends(const AnfNodePtr &stable_node, const AnfNodePtrList &free_nodes, const FuncGraphPtr &main_graph,
|
||||
const FuncGraphManagerPtr &mng) {
|
||||
AnfNodePtr modified_node = stable_node;
|
||||
for (const auto &free_node : free_nodes) {
|
||||
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), modified_node, free_node};
|
||||
auto depend_cnode = main_graph->NewCNode(d_inputs);
|
||||
depend_cnode->set_abstract(modified_node->abstract());
|
||||
main_graph->AddNode(depend_cnode);
|
||||
modified_node = depend_cnode;
|
||||
}
|
||||
|
||||
if (!free_nodes.empty()) {
|
||||
mng->Replace(stable_node, modified_node);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool DependFormater::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
|
||||
// 1. Try to remove redundant depend.
|
||||
bool changed = false;
|
||||
auto nodes = TopoSort(func_graph->get_return());
|
||||
std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) -> void {
|
||||
if (HasAbstractMonad(node)) {
|
||||
return;
|
||||
}
|
||||
if (RemoveRedundantDepend(node, mng)) {
|
||||
changed = true;
|
||||
}
|
||||
});
|
||||
|
||||
// Should re-toposort for changed graph.
|
||||
if (changed) {
|
||||
nodes = TopoSort(func_graph->get_return());
|
||||
}
|
||||
|
||||
// 2. Move depend to tail of graph.
|
||||
AnfNodePtrList old_depends;
|
||||
AnfNodePtrList free_nodes;
|
||||
|
||||
// Find depend and its free nodes.
|
||||
for (const auto &node : nodes) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimDepend) ||
|
||||
HasAbstractMonad(node->cast<CNodePtr>()->input(kDependAttachNodeIndex))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
old_depends.push_back(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
for (size_t id = kDependAttachNodeIndex; id < cnode->inputs().size(); ++id) {
|
||||
auto attach_node = cnode->input(id);
|
||||
if (!IsPrimitiveCNode(attach_node, prim::kPrimDepend)) {
|
||||
continue;
|
||||
}
|
||||
free_nodes.push_back(attach_node);
|
||||
}
|
||||
}
|
||||
|
||||
if (old_depends.empty()) {
|
||||
return changed;
|
||||
}
|
||||
|
||||
// Delete old depend.
|
||||
for (const auto &depend_anfnode : old_depends) {
|
||||
auto depend_cnode = depend_anfnode->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_cnode);
|
||||
auto depend_prior_node = depend_cnode->input(kControlDependPriorIndex);
|
||||
mng->Replace(depend_anfnode, depend_prior_node);
|
||||
}
|
||||
|
||||
// Add new depend node in tail.
|
||||
AnfNodePtr patron_node = FindPatronNode(func_graph, mng);
|
||||
AddDepends(patron_node, free_nodes, func_graph, mng);
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -1,37 +0,0 @@
|
|||
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class DependFormater : public Pass {
|
||||
public:
|
||||
DependFormater() : Pass("depend_formater") {}
|
||||
~DependFormater() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
};
|
||||
using DependFormaterPtr = std::shared_ptr<DependFormater>;
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
|
|
@ -26,7 +26,6 @@
|
|||
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
|
||||
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
|
||||
#include "backend/optimizer/graph_kernel/clean_all_in_once.h"
|
||||
#include "backend/optimizer/graph_kernel/depend_formater.h"
|
||||
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
|
||||
#include "backend/optimizer/graph_kernel/tensor_promotion.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
|
||||
|
@ -50,9 +49,6 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() {
|
|||
// Change Assign(p, a, U) to Assign(Depend(p, U), a)
|
||||
pm->AddPass(std::make_shared<SplitAssign>());
|
||||
|
||||
// Move the Depend nodes to the bottom of graph
|
||||
pm->AddPass(std::make_shared<DependFormater>());
|
||||
|
||||
// Reorder TransData-Cast to Cast-TransData,
|
||||
if (is_ascend) {
|
||||
pm->AddPass(std::make_shared<ReorderOps>());
|
||||
|
@ -142,8 +138,6 @@ PassManagerPtr GraphKernelOptimizer::Combine() {
|
|||
auto pm = std::make_shared<PassManager>("graphkernel_stage6_combine");
|
||||
// Enable parallel fusion
|
||||
if (is_gpu) {
|
||||
// Prevent fake loop in parallel fusion
|
||||
pm->AddPass(std::make_shared<DependFormater>());
|
||||
// Do parallel fusion for gpu device
|
||||
pm->AddPass(std::make_shared<ParallelOpFusion>(kGPUDevice, ParallelConfig(7)));
|
||||
}
|
||||
|
|
|
@ -142,10 +142,9 @@ void UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const AnfNodePtr &
|
|||
|
||||
for (const auto &getitem_user_iter : getitem_users) {
|
||||
auto getitem_user = getitem_user_iter.first;
|
||||
// 1. A previous pass `DependFormater` has ensured that all data users are directly link to its
|
||||
// input, without Depend node.
|
||||
// 1. Data users may not be linked directly to their input; they may be separated from it by a Depend node.
|
||||
// 2. If the `cnode` has another path to the getitem_user, it's unnecessary to add update_state and load node to
|
||||
// keep exec_order.
|
||||
// keep exec_order.
|
||||
if (HasPathToParamUser(cnode, getitem_user, getitem)) {
|
||||
mng->Replace(getitem, assign_to);
|
||||
continue;
|
||||
|
|
|
@ -95,33 +95,6 @@ void ProcessThroughPassCNode(std::function<bool(const AnfNodePtr &)> pass_fn,
|
|||
}
|
||||
}
|
||||
|
||||
// Takes non-monad Depend nodes out of the analysis relations in `node_rels`.
//
// For every Depend whose attach input has no monad abstract, the edges between
// the Depend and each of its inputs from kDependAttachNodeIndex onwards are
// cut in both directions. The Depend itself is then handed to
// ProcessThroughPassCNode as a node to pass through.
void ProcessDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
  OrderedSet<AnfNodePtr> depend_nodes;
  for (auto &[node, node_rel] : (*node_rels)) {
    if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
      continue;
    }
    auto depend_cnode = node->cast<CNodePtr>();
    if (HasAbstractMonad(depend_cnode->input(kDependAttachNodeIndex))) {
      continue;
    }

    // Detach every attached node from this Depend.
    const size_t input_num = depend_cnode->inputs().size();
    for (size_t pos = kDependAttachNodeIndex; pos < input_num; ++pos) {
      auto attached = depend_cnode->input(pos);

      // Attached node -> Depend edge.
      auto attached_rel = node_rels->find(attached);
      if (attached_rel != node_rels->end()) {
        attached_rel->second.nexts.erase(node);
      }

      // Depend -> attached node edge.
      auto &depend_pres = node_rel.pres;
      if (depend_pres.count(attached) != 0) {
        depend_pres.erase(attached);
      }
    }
    depend_nodes.insert(node);
  }

  // Eliminate the collected Depend nodes from the node relations.
  auto is_collected_depend = [&depend_nodes](const AnfNodePtr &candidate) {
    return depend_nodes.count(candidate) > 0;
  };
  ProcessThroughPassCNode(is_collected_depend, node_rels);
}
|
||||
|
||||
void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
|
||||
AnfNodePtrList latter_to_be_erased;
|
||||
for (auto &[node, node_rel] : (*node_rels)) {
|
||||
|
@ -441,8 +414,7 @@ OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const An
|
|||
|
||||
auto prior_node = get_info(node);
|
||||
for (const auto &input : (node->cast<CNodePtr>())->inputs()) {
|
||||
// Parameter for ControlDepend when depend mode is 1.
|
||||
if (!input->isa<CNode>() && !input->isa<Parameter>()) {
|
||||
if (!input->isa<CNode>()) {
|
||||
continue;
|
||||
}
|
||||
auto behind_node = get_info(input);
|
||||
|
@ -451,13 +423,11 @@ OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const An
|
|||
}
|
||||
}
|
||||
|
||||
ProcessDependCNode(&node_rels);
|
||||
ProcessThroughPassCNode(
|
||||
[](const AnfNodePtr &node) {
|
||||
return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem});
|
||||
},
|
||||
&node_rels);
|
||||
ProcessThroughPassCNode([](const AnfNodePtr &node) { return node->isa<Parameter>(); }, &node_rels);
|
||||
ProcessTailMakeTupleCNode(&node_rels);
|
||||
ProcessLocalStructure(&node_rels, &virtual_noout_nodes_, &ignore_noin_nodes_);
|
||||
|
||||
|
@ -707,51 +677,6 @@ void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &pa
|
|||
SetFusionInfoAttrToNode(attach_node, parallel_info);
|
||||
}
|
||||
|
||||
void PostProcessForNewSubGraphCNode(const AnfNodePtr &node, const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
auto mng = kernel_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(kernel_graph, true);
|
||||
kernel_graph->set_manager(mng);
|
||||
}
|
||||
|
||||
const auto &users = mng->node_users()[node];
|
||||
std::vector<std::pair<AnfNodePtr, int>> sons;
|
||||
for (const auto &[user, index] : users) {
|
||||
if (!IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
|
||||
sons.emplace_back(user, index);
|
||||
continue;
|
||||
}
|
||||
auto &[fake_first_grad_son, grad_index] = *((mng->node_users()[user]).begin());
|
||||
sons.emplace_back(fake_first_grad_son, grad_index);
|
||||
}
|
||||
|
||||
AnfNodePtrList latter_to_delete;
|
||||
for (const auto &[son, index] : sons) {
|
||||
if (!IsPrimitiveCNode(son, prim::kPrimDepend) || index != kDependAttachNodeIndex) {
|
||||
continue;
|
||||
}
|
||||
|
||||
latter_to_delete.push_back(son);
|
||||
}
|
||||
|
||||
if (latter_to_delete.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr>::iterator delete_begin = latter_to_delete.begin();
|
||||
if (latter_to_delete.size() == sons.size()) {
|
||||
// Left one Depend node relation and delete others!
|
||||
++delete_begin;
|
||||
}
|
||||
for (; delete_begin != latter_to_delete.end(); ++delete_begin) {
|
||||
auto depend_anfnode = *delete_begin;
|
||||
auto depend_cnode = depend_anfnode->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_cnode);
|
||||
auto depend_prior_node = depend_cnode->input(kRealInputIndexInDepend);
|
||||
mng->Replace(depend_anfnode, depend_prior_node);
|
||||
}
|
||||
}
|
||||
|
||||
void ParallelOpFusion::SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo ¶llel_info) {
|
||||
auto fusion_type = parallel_info.fusion_info()->FusionType();
|
||||
AnfAlgo::SetNodeAttr(kAttrParallelFusionType, MakeValue<std::string>(fusion_type), node);
|
||||
|
@ -776,7 +701,6 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo>
|
|||
SetFusedParallelOpAttrToReturnNode(parallel_infos[i]);
|
||||
AnfNodePtr sg_node;
|
||||
std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel");
|
||||
PostProcessForNewSubGraphCNode(sg_node, kernel_graph);
|
||||
AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node);
|
||||
DumpParallelFusionDetail(fuse_nodes, sg_node);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue