forked from mindspore-Ecosystem/mindspore
!8918 checkcircle should consider the edges coming from controldepent nodes
From: @lingyunli63 Reviewed-by: Signed-off-by:
This commit is contained in:
commit
08dc1481c7
|
@ -47,12 +47,13 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNode
|
|||
return is_fusable ? FOLLOW : EXCLUDE;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
|
||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
|
||||
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) {
|
||||
// Search fusable nodes according input direction.
|
||||
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
|
||||
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
|
||||
if (used_nodes.size() > 1) {
|
||||
used_nodes = RemoveCircle(used_nodes, false);
|
||||
used_nodes = RemoveCircle(used_nodes, dep_pri, false);
|
||||
}
|
||||
TopoSortForNodeList(&used_nodes);
|
||||
return used_nodes;
|
||||
|
@ -78,7 +79,8 @@ void SearchForDependNode(const AnfNodeSet &outputs_set, const AnfNodeIndexSet &u
|
|||
}
|
||||
|
||||
bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *eqv) {
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *eqv,
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
|
||||
AnfNodeSet outputs_set;
|
||||
for (auto out : *outputs) {
|
||||
outputs_set.insert(out);
|
||||
|
@ -112,6 +114,7 @@ bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_out
|
|||
auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs);
|
||||
mng->Replace(control_depend_node, new_control_depend);
|
||||
has_erase_outs = true;
|
||||
UpdateControlDependNode(depend_prior, control_depend_node, new_control_depend);
|
||||
}
|
||||
} else {
|
||||
it++;
|
||||
|
@ -120,7 +123,8 @@ bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_out
|
|||
return has_erase_outs;
|
||||
}
|
||||
|
||||
void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) {
|
||||
void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng,
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
|
||||
AnfNodePtrList vir_outputs;
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> eqv;
|
||||
auto fg_outputs = fg->output();
|
||||
|
@ -137,7 +141,7 @@ void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, con
|
|||
MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output";
|
||||
}
|
||||
|
||||
if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv)) {
|
||||
if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv, depend_prior)) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -159,6 +163,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
|
|||
std::unordered_set<AnfNodePtr> *fused_ops) {
|
||||
bool changed = false;
|
||||
auto mng = kernel_graph->manager();
|
||||
|
||||
// depend_prior[depend] = pair(prior, controlDependNode)
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior;
|
||||
InitDependPrior(todos, &depend_prior);
|
||||
|
||||
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||
auto node = (*iter)->cast<CNodePtr>();
|
||||
if (node == nullptr) {
|
||||
|
@ -172,7 +181,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
|
|||
continue;
|
||||
}
|
||||
|
||||
auto fuse_nodes = FindFuseCNodes(node);
|
||||
auto fuse_nodes = FindFuseCNodes(node, depend_prior);
|
||||
if (fuse_nodes.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
|
@ -182,11 +191,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
|
|||
AnfNodePtrList inputs;
|
||||
AnfNodePtrList outputs;
|
||||
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
|
||||
RemoveControlDependOut(fg, &outputs, mng);
|
||||
RemoveControlDependOut(fg, &outputs, mng, &depend_prior);
|
||||
ConvertNonscalarTensorToParameter(fg, &inputs);
|
||||
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
|
||||
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
|
||||
|
||||
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fuse_new_node, outputs);
|
||||
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
|
||||
|
||||
// Set graph kernel attr
|
||||
|
|
|
@ -15,13 +15,14 @@
|
|||
*/
|
||||
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/operator/ops.h"
|
||||
|
@ -97,15 +98,29 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNod
|
|||
}
|
||||
|
||||
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
|
||||
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes) {
|
||||
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes,
|
||||
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) {
|
||||
if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) {
|
||||
return false;
|
||||
}
|
||||
circle_nodes->clear();
|
||||
|
||||
auto InputEdges = [&depend_prior](CNodePtr cnode) {
|
||||
std::set<AnfNodePtr> edges;
|
||||
auto range = depend_prior.equal_range(cnode);
|
||||
for (auto iter = range.first; iter != range.second; ++iter) {
|
||||
edges.insert(iter->second.first);
|
||||
}
|
||||
auto inputs = cnode->inputs();
|
||||
for (auto input : inputs) {
|
||||
edges.insert(input);
|
||||
}
|
||||
return edges;
|
||||
};
|
||||
|
||||
std::set<AnfNodePtr> cached_done_set;
|
||||
auto cnode = check_node->cast<CNodePtr>();
|
||||
const auto &inputs = cnode->inputs();
|
||||
const auto &inputs = InputEdges(cnode);
|
||||
// there is a input not in fused_op_set, but the input depends on the fused_op_set
|
||||
for (auto input : inputs) {
|
||||
if (input->isa<CNode>() && !fused_op_set.count(input)) {
|
||||
|
@ -128,7 +143,7 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
|
|||
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode_ptr = node->cast<CNodePtr>();
|
||||
for (auto it : cnode_ptr->inputs()) {
|
||||
for (auto it : InputEdges(cnode_ptr)) {
|
||||
if (it->isa<CNode>()) {
|
||||
todos.push_back(it);
|
||||
}
|
||||
|
@ -148,7 +163,9 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
|
|||
return !circle_nodes->empty();
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) {
|
||||
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
|
||||
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior,
|
||||
bool is_backward) {
|
||||
std::set<AnfNodePtr> cached_unconnected_set;
|
||||
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
|
||||
auto include = [&fused_op_set](const AnfNodePtr &node) {
|
||||
|
@ -161,7 +178,7 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
|
|||
std::vector<AnfNodePtr> circle_nodes;
|
||||
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
|
||||
circle_nodes.clear();
|
||||
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes);
|
||||
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes, depend_prior);
|
||||
// delete the circle node and the node which depend on the circle node in fused op
|
||||
if (has_circle) {
|
||||
auto mng = (*iter)->func_graph()->manager();
|
||||
|
@ -294,7 +311,8 @@ void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
|
|||
lst->assign(res.begin(), res.end());
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
|
||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
|
||||
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) {
|
||||
auto func_graph = cnode->func_graph();
|
||||
auto mng = func_graph->manager();
|
||||
// Search fusable nodes according input direction.
|
||||
|
@ -307,7 +325,7 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
|
|||
|
||||
used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end());
|
||||
if (used_nodes.size() > 1) {
|
||||
used_nodes = RemoveCircle(used_nodes);
|
||||
used_nodes = RemoveCircle(used_nodes, dep_pri);
|
||||
}
|
||||
used_nodes = RemoveWildGetitem(used_nodes);
|
||||
TopoSortForNodeList(&used_nodes);
|
||||
|
@ -316,8 +334,18 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
|
|||
|
||||
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
auto mng = kernel_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(kernel_graph, true);
|
||||
kernel_graph->set_manager(mng);
|
||||
}
|
||||
auto todos = TopoSort(kernel_graph->get_return());
|
||||
std::reverse(todos.begin(), todos.end());
|
||||
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior;
|
||||
InitDependPrior(todos, &depend_prior);
|
||||
|
||||
bool changed = false;
|
||||
auto &todos = kernel_graph->execution_order();
|
||||
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||
auto node = *iter;
|
||||
if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) {
|
||||
|
@ -333,13 +361,16 @@ bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph)
|
|||
}
|
||||
}
|
||||
|
||||
auto fuse_nodes = FindFuseCNodes(node);
|
||||
auto fuse_nodes = FindFuseCNodes(node->cast<CNodePtr>(), depend_prior);
|
||||
if (fuse_nodes.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
changed = true;
|
||||
|
||||
FuseNodesToSubGraph(fuse_nodes, kernel_graph, "");
|
||||
AnfNodePtr fused_new_node;
|
||||
AnfNodePtrList old_outputs;
|
||||
std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "");
|
||||
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs);
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
|
|
@ -16,11 +16,13 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
|
||||
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <limits>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
|
||||
|
@ -29,7 +31,9 @@ namespace opt {
|
|||
const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
|
||||
"LambNextMV", "LambUpdateWithLR"};
|
||||
|
||||
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward = true);
|
||||
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
|
||||
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior,
|
||||
bool is_backward = true);
|
||||
|
||||
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);
|
||||
|
||||
|
|
|
@ -14,20 +14,24 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <tuple>
|
||||
#include <unordered_set>
|
||||
#include "pipeline/jit/parse/python_adapter.h"
|
||||
#include "pipeline/jit/action.h"
|
||||
#include <utility>
|
||||
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "vm/segment_runner.h"
|
||||
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
|
||||
#include "pipeline/jit/parse/python_adapter.h"
|
||||
#include "pipeline/jit/action.h"
|
||||
#include "vm/segment_runner.h"
|
||||
#if ENABLE_GPU
|
||||
#include "runtime/device/gpu/kernel_info_setter.h"
|
||||
#endif
|
||||
|
@ -526,12 +530,9 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f
|
|||
}
|
||||
}
|
||||
|
||||
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix) {
|
||||
if (fuse_nodes.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
||||
const std::string &postfix) {
|
||||
auto mng = kernel_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(kernel_graph, true);
|
||||
|
@ -565,6 +566,8 @@ void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
|||
}
|
||||
fuse_op_name += postfix;
|
||||
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
|
||||
|
||||
return std::make_tuple(fuse_new_node, src_outputs);
|
||||
}
|
||||
|
||||
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
|
||||
|
@ -737,7 +740,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
|
|||
prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape,
|
||||
prim::kPrimTranspose, prim::kPrimCast};
|
||||
prim::kPrimTranspose, prim::kPrimCast, prim::kPrimRealDiv};
|
||||
#elif ENABLE_GPU
|
||||
std::vector<PrimitivePtr> fusible_basic_ops = {
|
||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
|
||||
|
@ -786,5 +789,123 @@ void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
|
|||
device::gpu::SetKernelInfo(cnode, kernel_type);
|
||||
#endif
|
||||
}
|
||||
|
||||
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
|
||||
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||
auto cnode = (*iter)->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimControlDepend)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto prior_node = cnode->input(kControlDependPriorIndex);
|
||||
auto depend_node = cnode->input(kControlDependBehindIndex);
|
||||
MS_EXCEPTION_IF_NULL(prior_node);
|
||||
MS_EXCEPTION_IF_NULL(depend_node);
|
||||
std::vector<AnfNodePtr> prior_nodes = {prior_node};
|
||||
std::vector<AnfNodePtr> depend_nodes = {depend_node};
|
||||
|
||||
int depend_mode = 0;
|
||||
if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
|
||||
depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode);
|
||||
}
|
||||
|
||||
auto GetOutputNodes = [cnode](const AnfNodePtr ¶m) -> std::vector<AnfNodePtr> {
|
||||
std::vector<AnfNodePtr> out_nodes;
|
||||
auto user_set = param->func_graph()->manager()->node_users()[param];
|
||||
for (auto iter = user_set.cbegin(); iter != user_set.cend(); ++iter) {
|
||||
if (iter->first != cnode) {
|
||||
out_nodes.push_back(iter->first);
|
||||
}
|
||||
}
|
||||
return out_nodes;
|
||||
};
|
||||
|
||||
if (prior_node->isa<Parameter>() && depend_mode == 1) {
|
||||
prior_nodes = GetOutputNodes(prior_node);
|
||||
}
|
||||
if (depend_node->isa<Parameter>()) {
|
||||
depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector<AnfNodePtr>{};
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> real_prior_nodes;
|
||||
std::set<AnfNodePtr> prior_visited;
|
||||
for (const auto &tmp : prior_nodes) {
|
||||
AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
|
||||
}
|
||||
prior_visited.clear();
|
||||
std::vector<AnfNodePtr> real_depend_nodes;
|
||||
std::set<AnfNodePtr> depend_visited;
|
||||
for (const auto &tmp : depend_nodes) {
|
||||
AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
|
||||
}
|
||||
depend_visited.clear();
|
||||
|
||||
for (auto &prior : real_prior_nodes) {
|
||||
if (AnfAlgo::CheckPrimitiveType(prior, prim::kPrimControlDepend)) {
|
||||
continue;
|
||||
}
|
||||
for (auto &depend : real_depend_nodes) {
|
||||
if (AnfAlgo::CheckPrimitiveType(depend, prim::kPrimControlDepend)) {
|
||||
continue;
|
||||
}
|
||||
depend_prior->insert({depend, std::make_pair(prior, cnode)});
|
||||
}
|
||||
}
|
||||
real_prior_nodes.clear();
|
||||
real_depend_nodes.clear();
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||
const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend) {
|
||||
for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
|
||||
if (iter->second.second == control_depend_node) {
|
||||
iter = depend_prior->erase(iter);
|
||||
continue;
|
||||
}
|
||||
++iter;
|
||||
}
|
||||
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_depend_prior;
|
||||
InitDependPrior(std::vector<AnfNodePtr>{new_control_depend}, &new_depend_prior);
|
||||
for (auto item : new_depend_prior) {
|
||||
depend_prior->insert(item);
|
||||
}
|
||||
}
|
||||
|
||||
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) {
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri;
|
||||
|
||||
for (size_t out_idx = 0; out_idx < outputs.size(); ++out_idx) {
|
||||
if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimMakeTuple)) {
|
||||
MS_LOG(ERROR) << "Need real outputs of makeTuple";
|
||||
}
|
||||
if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimTupleGetItem)) {
|
||||
continue;
|
||||
}
|
||||
for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
|
||||
if (iter->first == outputs[out_idx]) {
|
||||
new_fuse_cnode_dep_pri.insert({new_fuse_cnode, iter->second});
|
||||
iter = depend_prior->erase(iter);
|
||||
continue;
|
||||
}
|
||||
if (iter->second.first == outputs[out_idx]) {
|
||||
new_fuse_cnode_dep_pri.insert({iter->first, std::make_pair(new_fuse_cnode, iter->second.second)});
|
||||
iter = depend_prior->erase(iter);
|
||||
continue;
|
||||
}
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto item : new_fuse_cnode_dep_pri) {
|
||||
depend_prior->insert(item);
|
||||
}
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,17 +15,20 @@
|
|||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_set>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
|
||||
#include <nlohmann/json.hpp>
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -48,8 +51,9 @@ AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphP
|
|||
const AnfNodePtrList &outputs);
|
||||
void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
|
||||
const AnfNodePtrList &outputs);
|
||||
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix);
|
||||
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
|
||||
const std::shared_ptr<session::KernelGraph> &kernel_graph,
|
||||
const std::string &postfix);
|
||||
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc);
|
||||
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
|
||||
std::map<std::string, AnfNodePtr> *address_node_map);
|
||||
|
@ -60,6 +64,12 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
|
|||
std::vector<PrimitivePtr> GetFusibleOpList();
|
||||
bool IsBasicFuseOp(const AnfNodePtr &node);
|
||||
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
|
||||
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior);
|
||||
void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||
const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend);
|
||||
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
|
||||
|
|
|
@ -1417,5 +1417,46 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const A
|
|||
}
|
||||
return device_shape;
|
||||
}
|
||||
|
||||
void AnfRuntimeAlgorithm::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
|
||||
std::set<AnfNodePtr> *visited) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
MS_EXCEPTION_IF_NULL(visited);
|
||||
if (visited->find(anf_node) != visited->end()) {
|
||||
MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
|
||||
return;
|
||||
}
|
||||
visited->insert(anf_node);
|
||||
if (AnfAlgo::IsRealKernel(anf_node)) {
|
||||
result->emplace_back(anf_node);
|
||||
return;
|
||||
}
|
||||
if (!anf_node->isa<CNode>()) {
|
||||
return;
|
||||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().empty()) {
|
||||
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString();
|
||||
}
|
||||
auto input0 = cnode->input(0);
|
||||
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
GetAllFatherRealNode(cnode->input(i), result, visited);
|
||||
}
|
||||
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
|
||||
if (cnode->inputs().size() != kTupleGetItemInputSize) {
|
||||
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
||||
}
|
||||
GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
|
||||
} else if (IsPrimitive(input0, prim::kPrimDepend)) {
|
||||
if (cnode->inputs().size() != kDependInputSize) {
|
||||
MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
|
||||
}
|
||||
GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
|
||||
GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
|
||||
}
|
||||
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -232,6 +232,9 @@ class AnfRuntimeAlgorithm {
|
|||
static bool IsNodeDynamicShape(const AnfNodePtr &node);
|
||||
static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
|
||||
static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
|
||||
// Find control_depend real input nodes.
|
||||
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
|
||||
std::set<AnfNodePtr> *visited);
|
||||
};
|
||||
} // namespace session
|
||||
using AnfAlgo = session::AnfRuntimeAlgorithm;
|
||||
|
|
|
@ -725,47 +725,6 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
|
|||
return output_nodes;
|
||||
}
|
||||
|
||||
// Find control_depend real input nodes.
|
||||
void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, std::set<AnfNodePtr> *visited) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
MS_EXCEPTION_IF_NULL(visited);
|
||||
if (visited->find(anf_node) != visited->end()) {
|
||||
MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has alreday been visited";
|
||||
return;
|
||||
}
|
||||
visited->insert(anf_node);
|
||||
if (AnfAlgo::IsRealKernel(anf_node)) {
|
||||
result->emplace_back(anf_node);
|
||||
return;
|
||||
}
|
||||
if (!anf_node->isa<CNode>()) {
|
||||
return;
|
||||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().empty()) {
|
||||
MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << anf_node->DebugString();
|
||||
}
|
||||
auto input0 = cnode->input(0);
|
||||
if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
|
||||
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
|
||||
GetAllFatherRealNode(cnode->input(i), result, visited);
|
||||
}
|
||||
} else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
|
||||
if (cnode->inputs().size() != kTupleGetItemInputSize) {
|
||||
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
||||
}
|
||||
GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
|
||||
} else if (IsPrimitive(input0, prim::kPrimDepend)) {
|
||||
if (cnode->inputs().size() != kDependInputSize) {
|
||||
MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
|
||||
}
|
||||
GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
|
||||
GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
|
||||
}
|
||||
}
|
||||
|
||||
// update the depend relations of control depend
|
||||
void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
|
||||
for (const auto &node : depends) {
|
||||
|
@ -800,12 +759,12 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
|
|||
std::vector<AnfNodePtr> real_prior_nodes;
|
||||
std::set<AnfNodePtr> prior_visited;
|
||||
for (const auto &tmp : prior_nodes) {
|
||||
GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
|
||||
AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
|
||||
}
|
||||
std::vector<AnfNodePtr> real_depend_nodes;
|
||||
std::set<AnfNodePtr> depend_visited;
|
||||
for (const auto &tmp : depend_nodes) {
|
||||
GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
|
||||
AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
|
||||
}
|
||||
UpdateNodeInputOutputEdges(real_prior_nodes, real_depend_nodes);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue