delete depend format for its order use

This commit is contained in:
tronzhang 2021-03-27 16:19:37 +08:00
parent eb861aa93e
commit 4182c1f02a
6 changed files with 8 additions and 291 deletions

View File

@ -15,12 +15,15 @@
*/
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include <memory>
#include <algorithm>
#include <map>
#include <memory>
#include <unordered_set>
#include <unordered_map>
#include <vector>
#include <utility>
#include <set>
#include <string>
#include <vector>
#include "base/core_ops.h"
#include "ir/graph_utils.h"

View File

@ -1,166 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/depend_formater.h"
#include <tuple>
#include <utility>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
namespace mindspore {
namespace opt {
namespace {
/**
 * Replace Depend nodes that attach `node` by their real input when they are redundant.
 *
 * The consumers ("sons") of `node` are its direct users; a TupleGetItem user is looked through and
 * represented by the first user of that getitem. A son counts as a redundant candidate when it is a
 * Depend node that uses the value at input position kDependAttachNodeIndex.
 *
 * If every son is such a Depend, the first candidate is kept so that one Depend relation survives;
 * otherwise all candidates are replaced by their input at kRealInputIndexInDepend.
 *
 * @param node Node whose consumers are inspected.
 * @param mng  Manager used to query users and to perform the replacement.
 * @return true when at least one candidate Depend was found (even if that single one is kept),
 *         false when there is none.
 */
bool RemoveRedundantDepend(const AnfNodePtr &node, const FuncGraphManagerPtr &mng) {
  const auto &users = mng->node_users()[node];
  std::vector<std::pair<AnfNodePtr, int>> sons;
  for (const auto &[user, index] : users) {
    if (!IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
      sons.emplace_back(user, index);
      continue;
    }
    // Look through the TupleGetItem: its first user stands in for it.
    auto &getitem_users = mng->node_users()[user];
    auto first_user = getitem_users.begin();
    if (first_user != getitem_users.end()) {
      const auto &[fake_first_grad_son, grad_index] = *first_user;
      sons.emplace_back(fake_first_grad_son, grad_index);
    }
    // A TupleGetItem without users has no consumer to inspect; dereferencing begin() of its empty
    // user set would be undefined behaviour, so it is skipped.
  }

  AnfNodePtrList latter_to_delete;
  for (const auto &[son, index] : sons) {
    if (!IsPrimitiveCNode(son, prim::kPrimDepend) || index != kDependAttachNodeIndex) {
      continue;
    }
    latter_to_delete.push_back(son);
  }
  if (latter_to_delete.empty()) {
    return false;
  }

  auto delete_begin = latter_to_delete.begin();
  if (latter_to_delete.size() == sons.size()) {
    // Every consumer is a Depend: keep the first one and delete the others.
    ++delete_begin;
  }
  for (; delete_begin != latter_to_delete.end(); ++delete_begin) {
    auto depend_anfnode = *delete_begin;
    auto depend_cnode = depend_anfnode->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(depend_cnode);
    auto depend_prior_node = depend_cnode->input(kRealInputIndexInDepend);
    mng->Replace(depend_anfnode, depend_prior_node);
  }
  return true;
}
/**
 * Pick the node that new Depend nodes will be chained onto.
 *
 * Takes the first data input of the graph's return node; when that input is a MakeTuple, the first
 * data input of the MakeTuple is returned instead.
 *
 * @param main_graph Graph whose return node is inspected.
 * @param mng        Not used by this function.
 * @return The selected output node.
 */
AnfNodePtr FindPatronNode(const FuncGraphPtr &main_graph, const FuncGraphManagerPtr &mng) {
  auto ret_cnode = main_graph->get_return()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(ret_cnode);
  auto graph_output = ret_cnode->input(kFirstDataInputIndex);
  if (!IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
    // Single output: the output itself is the patron.
    return graph_output;
  }
  // Tuple output: use the first element of the MakeTuple.
  auto tuple_cnode = graph_output->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(tuple_cnode);
  return tuple_cnode->input(kFirstDataInputIndex);
}
void AddDepends(const AnfNodePtr &stable_node, const AnfNodePtrList &free_nodes, const FuncGraphPtr &main_graph,
const FuncGraphManagerPtr &mng) {
AnfNodePtr modified_node = stable_node;
for (const auto &free_node : free_nodes) {
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), modified_node, free_node};
auto depend_cnode = main_graph->NewCNode(d_inputs);
depend_cnode->set_abstract(modified_node->abstract());
main_graph->AddNode(depend_cnode);
modified_node = depend_cnode;
}
if (!free_nodes.empty()) {
mng->Replace(stable_node, modified_node);
}
}
} // namespace
bool DependFormater::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
// 1. Try to remove redundant depend.
bool changed = false;
auto nodes = TopoSort(func_graph->get_return());
std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) -> void {
if (HasAbstractMonad(node)) {
return;
}
if (RemoveRedundantDepend(node, mng)) {
changed = true;
}
});
// Should re-toposort for changed graph.
if (changed) {
nodes = TopoSort(func_graph->get_return());
}
// 2. Move depend to tail of graph.
AnfNodePtrList old_depends;
AnfNodePtrList free_nodes;
// Find depend and its free nodes.
for (const auto &node : nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimDepend) ||
HasAbstractMonad(node->cast<CNodePtr>()->input(kDependAttachNodeIndex))) {
continue;
}
old_depends.push_back(node);
auto cnode = node->cast<CNodePtr>();
for (size_t id = kDependAttachNodeIndex; id < cnode->inputs().size(); ++id) {
auto attach_node = cnode->input(id);
if (!IsPrimitiveCNode(attach_node, prim::kPrimDepend)) {
continue;
}
free_nodes.push_back(attach_node);
}
}
if (old_depends.empty()) {
return changed;
}
// Delete old depend.
for (const auto &depend_anfnode : old_depends) {
auto depend_cnode = depend_anfnode->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_cnode);
auto depend_prior_node = depend_cnode->input(kControlDependPriorIndex);
mng->Replace(depend_anfnode, depend_prior_node);
}
// Add new depend node in tail.
AnfNodePtr patron_node = FindPatronNode(func_graph, mng);
AddDepends(patron_node, free_nodes, func_graph, mng);
return true;
}
} // namespace opt
} // namespace mindspore

View File

@ -1,37 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_
#include <map>
#include <memory>
#include "backend/optimizer/common/pass.h"
#include "ir/func_graph.h"
namespace mindspore {
namespace opt {
// Optimizer pass registered under the name "depend_formater".
class DependFormater : public Pass {
public:
// Forwards the pass name to the Pass base class.
DependFormater() : Pass("depend_formater") {}
~DependFormater() override = default;
// Executes the pass on `graph`; the result tells whether the graph was modified.
bool Run(const FuncGraphPtr &graph) override;
};
// Shared-pointer alias for DependFormater.
using DependFormaterPtr = std::shared_ptr<DependFormater>;
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_DEPEND_FORMATER_H_

View File

@ -27,7 +27,6 @@
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include "backend/optimizer/graph_kernel/clean_all_in_once.h"
#include "backend/optimizer/graph_kernel/depend_formater.h"
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
#include "backend/optimizer/graph_kernel/tensor_promotion.h"
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
@ -51,9 +50,6 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() {
// Change Assign(p, a, U) to Assign(Depend(p, U), a)
pm->AddPass(std::make_shared<SplitAssign>());
// Move the Depend nodes to the bottom of graph
pm->AddPass(std::make_shared<DependFormater>());
// Reorder TransData-Cast to Cast-TransData,
if (is_ascend) {
pm->AddPass(std::make_shared<ReorderOps>());
@ -145,8 +141,6 @@ PassManagerPtr GraphKernelOptimizer::Combine() {
auto pm = std::make_shared<PassManager>("graphkernel_stage6_combine");
// Enable parallel fusion
if (is_gpu) {
// Prevent fake loop in parallel fusion
pm->AddPass(std::make_shared<DependFormater>());
// Do parallel fusion for gpu device
pm->AddPass(std::make_shared<ParallelOpFusion>(kGPUDevice, ParallelConfig(7)));
}

View File

@ -142,10 +142,9 @@ void UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const AnfNodePtr &
for (const auto &getitem_user_iter : getitem_users) {
auto getitem_user = getitem_user_iter.first;
// 1. A previous pass `DependFormater` has ensured that all data users are directly link to its
// input, without Depend node.
// 1. Data users may not be linked directly to their input; they may be separated from it by Depend nodes.
// 2. If the `cnode` has another path to the getitem_user, it's unnecessary to add update_state and load node to
// keep exec_order.
// keep exec_order.
if (HasPathToParamUser(cnode, getitem_user, getitem)) {
mng->Replace(getitem, assign_to);
continue;

View File

@ -95,33 +95,6 @@ void ProcessThroughPassCNode(std::function<bool(const AnfNodePtr &)> pass_fn,
}
}
/**
 * Detach non-monad Depend nodes from their attached inputs in the relation map, then eliminate those
 * Depend nodes from the relations via ProcessThroughPassCNode.
 *
 * For every Depend entry whose attach input is not a monad, each input from position
 * kDependAttachNodeIndex onwards is unlinked in both directions: the Depend is erased from the
 * attached node's `nexts` and the attached node is erased from the Depend's `pres`.
 *
 * @param node_rels Relation map that is modified in place.
 */
void ProcessDependCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
  OrderedSet<AnfNodePtr> depends_to_bypass;
  for (auto &[cur_node, cur_rel] : (*node_rels)) {
    if (!IsPrimitiveCNode(cur_node, prim::kPrimDepend)) {
      continue;
    }
    auto depend_cnode = cur_node->cast<CNodePtr>();
    if (HasAbstractMonad(depend_cnode->input(kDependAttachNodeIndex))) {
      continue;
    }

    // Make attached nodes detach from this Depend node.
    const size_t input_num = depend_cnode->inputs().size();
    for (size_t pos = kDependAttachNodeIndex; pos < input_num; ++pos) {
      auto attached = depend_cnode->input(pos);
      auto rel_iter = node_rels->find(attached);
      if (rel_iter != node_rels->end()) {
        rel_iter->second.nexts.erase(cur_node);
      }
      if (cur_rel.pres.count(attached) != 0) {
        cur_rel.pres.erase(attached);
      }
    }
    depends_to_bypass.insert(cur_node);
  }

  // Eliminate the collected Depend nodes from the node relations.
  auto is_bypassed = [&depends_to_bypass](const AnfNodePtr &node) { return depends_to_bypass.count(node) > 0; };
  ProcessThroughPassCNode(is_bypassed, node_rels);
}
void ProcessTailMakeTupleCNode(OrderedMap<AnfNodePtr, NodeRelation> *node_rels) {
AnfNodePtrList latter_to_be_erased;
for (auto &[node, node_rel] : (*node_rels)) {
@ -441,8 +414,7 @@ OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const An
auto prior_node = get_info(node);
for (const auto &input : (node->cast<CNodePtr>())->inputs()) {
// Parameter for ControlDepend when depend mode is 1.
if (!input->isa<CNode>() && !input->isa<Parameter>()) {
if (!input->isa<CNode>()) {
continue;
}
auto behind_node = get_info(input);
@ -451,13 +423,11 @@ OrderedMap<AnfNodePtr, NodeRelation> ParallelOpFusion::GenAnalysisGraph(const An
}
}
ProcessDependCNode(&node_rels);
ProcessThroughPassCNode(
[](const AnfNodePtr &node) {
return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem});
},
&node_rels);
ProcessThroughPassCNode([](const AnfNodePtr &node) { return node->isa<Parameter>(); }, &node_rels);
ProcessTailMakeTupleCNode(&node_rels);
ProcessLocalStructure(&node_rels, &virtual_noout_nodes_, &ignore_noin_nodes_);
@ -707,51 +677,6 @@ void ParallelOpFusion::SetFusedParallelOpAttrToReturnNode(const ParallelInfo &pa
SetFusionInfoAttrToNode(attach_node, parallel_info);
}
void PostProcessForNewSubGraphCNode(const AnfNodePtr &node, const std::shared_ptr<session::KernelGraph> &kernel_graph) {
auto mng = kernel_graph->manager();
if (mng == nullptr) {
mng = Manage(kernel_graph, true);
kernel_graph->set_manager(mng);
}
const auto &users = mng->node_users()[node];
std::vector<std::pair<AnfNodePtr, int>> sons;
for (const auto &[user, index] : users) {
if (!IsPrimitiveCNode(user, prim::kPrimTupleGetItem)) {
sons.emplace_back(user, index);
continue;
}
auto &[fake_first_grad_son, grad_index] = *((mng->node_users()[user]).begin());
sons.emplace_back(fake_first_grad_son, grad_index);
}
AnfNodePtrList latter_to_delete;
for (const auto &[son, index] : sons) {
if (!IsPrimitiveCNode(son, prim::kPrimDepend) || index != kDependAttachNodeIndex) {
continue;
}
latter_to_delete.push_back(son);
}
if (latter_to_delete.empty()) {
return;
}
std::vector<AnfNodePtr>::iterator delete_begin = latter_to_delete.begin();
if (latter_to_delete.size() == sons.size()) {
// Left one Depend node relation and delete others!
++delete_begin;
}
for (; delete_begin != latter_to_delete.end(); ++delete_begin) {
auto depend_anfnode = *delete_begin;
auto depend_cnode = depend_anfnode->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_cnode);
auto depend_prior_node = depend_cnode->input(kRealInputIndexInDepend);
mng->Replace(depend_anfnode, depend_prior_node);
}
}
void ParallelOpFusion::SetFusionInfoAttrToNode(const AnfNodePtr &node, const ParallelInfo &parallel_info) {
auto fusion_type = parallel_info.fusion_info()->FusionType();
AnfAlgo::SetNodeAttr(kAttrParallelFusionType, MakeValue<std::string>(fusion_type), node);
@ -776,7 +701,6 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo>
SetFusedParallelOpAttrToReturnNode(parallel_infos[i]);
AnfNodePtr sg_node;
std::tie(sg_node, std::ignore) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "parallel");
PostProcessForNewSubGraphCNode(sg_node, kernel_graph);
AnfAlgo::SetNodeAttr(kAttrCompositeType, MakeValue("parallel_fusion"), sg_node);
DumpParallelFusionDetail(fuse_nodes, sg_node);
}