Consider ControlDepend edges in CheckCircle

This commit is contained in:
lingyunli63 2020-11-19 20:22:56 +08:00
parent 534cc9bbe9
commit e6a5fc0739
8 changed files with 264 additions and 86 deletions

View File

@ -47,12 +47,13 @@ IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNode
return is_fusable ? FOLLOW : EXCLUDE;
}
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) {
// Search fusable nodes according input direction.
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
if (used_nodes.size() > 1) {
used_nodes = RemoveCircle(used_nodes, false);
used_nodes = RemoveCircle(used_nodes, dep_pri, false);
}
TopoSortForNodeList(&used_nodes);
return used_nodes;
@ -78,7 +79,8 @@ void SearchForDependNode(const AnfNodeSet &outputs_set, const AnfNodeIndexSet &u
}
bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_outputs, const FuncGraphManagerPtr &mng,
std::unordered_map<AnfNodePtr, AnfNodePtr> *eqv) {
std::unordered_map<AnfNodePtr, AnfNodePtr> *eqv,
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
AnfNodeSet outputs_set;
for (auto out : *outputs) {
outputs_set.insert(out);
@ -112,6 +114,7 @@ bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_out
auto new_control_depend = control_depend_node->func_graph()->NewCNode(new_control_depend_inputs);
mng->Replace(control_depend_node, new_control_depend);
has_erase_outs = true;
UpdateControlDependNode(depend_prior, control_depend_node, new_control_depend);
}
} else {
it++;
@ -120,7 +123,8 @@ bool FindControlDependOut(AnfNodePtrList *outputs, const AnfNodePtrList &vir_out
return has_erase_outs;
}
void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng) {
void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, const FuncGraphManagerPtr &mng,
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
AnfNodePtrList vir_outputs;
std::unordered_map<AnfNodePtr, AnfNodePtr> eqv;
auto fg_outputs = fg->output();
@ -137,7 +141,7 @@ void RemoveControlDependOut(const FuncGraphPtr &fg, AnfNodePtrList *outputs, con
MS_LOG(EXCEPTION) << "The size of virtual output of the fg is not the same with the real output";
}
if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv)) {
if (!FindControlDependOut(outputs, vir_outputs, mng, &eqv, depend_prior)) {
return;
}
@ -159,6 +163,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
std::unordered_set<AnfNodePtr> *fused_ops) {
bool changed = false;
auto mng = kernel_graph->manager();
// depend_prior[depend] = pair(prior, controlDependNode)
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior;
InitDependPrior(todos, &depend_prior);
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto node = (*iter)->cast<CNodePtr>();
if (node == nullptr) {
@ -172,7 +181,7 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
continue;
}
auto fuse_nodes = FindFuseCNodes(node);
auto fuse_nodes = FindFuseCNodes(node, depend_prior);
if (fuse_nodes.size() <= 1) {
continue;
}
@ -182,11 +191,11 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr
AnfNodePtrList inputs;
AnfNodePtrList outputs;
std::tie(fg, inputs, outputs) = compile::TransformSegmentToAnfGraph(fuse_nodes);
RemoveControlDependOut(fg, &outputs, mng);
RemoveControlDependOut(fg, &outputs, mng, &depend_prior);
ConvertNonscalarTensorToParameter(fg, &inputs);
auto fuse_new_node = CreateNewFuseCNode(kernel_graph, fg, inputs, outputs);
SetNewKernelInfo(fuse_new_node, fg, inputs, outputs, AnfAlgo::GetProcessor(fuse_nodes[0]));
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fuse_new_node, outputs);
ReplaceNewFuseCNode(kernel_graph, fuse_new_node, outputs);
// Set graph kernel attr

View File

@ -15,13 +15,14 @@
*/
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include <memory>
#include <string>
#include <algorithm>
#include <unordered_set>
#include <map>
#include <set>
#include <memory>
#include <queue>
#include <string>
#include <set>
#include <unordered_set>
#include <utility>
#include <vector>
#include "frontend/operator/ops.h"
@ -97,15 +98,29 @@ IncludeType IncludeFusedBasicOpBackward(const AnfNodePtr &cur_node, const AnfNod
}
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes) {
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) {
if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) {
return false;
}
circle_nodes->clear();
auto InputEdges = [&depend_prior](CNodePtr cnode) {
std::set<AnfNodePtr> edges;
auto range = depend_prior.equal_range(cnode);
for (auto iter = range.first; iter != range.second; ++iter) {
edges.insert(iter->second.first);
}
auto inputs = cnode->inputs();
for (auto input : inputs) {
edges.insert(input);
}
return edges;
};
std::set<AnfNodePtr> cached_done_set;
auto cnode = check_node->cast<CNodePtr>();
const auto &inputs = cnode->inputs();
const auto &inputs = InputEdges(cnode);
// there is an input that is not in fused_op_set, but that input depends on fused_op_set
for (auto input : inputs) {
if (input->isa<CNode>() && !fused_op_set.count(input)) {
@ -128,7 +143,7 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
if (node->isa<CNode>()) {
auto cnode_ptr = node->cast<CNodePtr>();
for (auto it : cnode_ptr->inputs()) {
for (auto it : InputEdges(cnode_ptr)) {
if (it->isa<CNode>()) {
todos.push_back(it);
}
@ -148,7 +163,9 @@ bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &che
return !circle_nodes->empty();
}
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward) {
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior,
bool is_backward) {
std::set<AnfNodePtr> cached_unconnected_set;
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
auto include = [&fused_op_set](const AnfNodePtr &node) {
@ -161,7 +178,7 @@ std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bo
std::vector<AnfNodePtr> circle_nodes;
for (auto iter = fused_op.rbegin(); iter != fused_op.rend(); ++iter) {
circle_nodes.clear();
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes);
bool has_circle = CheckCircle(fused_op_set, *iter, &cached_unconnected_set, &circle_nodes, depend_prior);
// delete the circle node and the node which depend on the circle node in fused op
if (has_circle) {
auto mng = (*iter)->func_graph()->manager();
@ -294,7 +311,8 @@ void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
lst->assign(res.begin(), res.end());
}
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) {
auto func_graph = cnode->func_graph();
auto mng = func_graph->manager();
// Search fusable nodes according input direction.
@ -307,7 +325,7 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
used_nodes.insert(used_nodes.end(), user_nodes.begin() + 1, user_nodes.end());
if (used_nodes.size() > 1) {
used_nodes = RemoveCircle(used_nodes);
used_nodes = RemoveCircle(used_nodes, dep_pri);
}
used_nodes = RemoveWildGetitem(used_nodes);
TopoSortForNodeList(&used_nodes);
@ -316,8 +334,18 @@ std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode) {
bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto mng = kernel_graph->manager();
if (mng == nullptr) {
mng = Manage(kernel_graph, true);
kernel_graph->set_manager(mng);
}
auto todos = TopoSort(kernel_graph->get_return());
std::reverse(todos.begin(), todos.end());
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior;
InitDependPrior(todos, &depend_prior);
bool changed = false;
auto &todos = kernel_graph->execution_order();
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto node = *iter;
if (!AnfAlgo::IsGraphKernel(node) || !kernel_graph->nodes().contains(node)) {
@ -333,13 +361,16 @@ bool FuseCompositeOps(const std::shared_ptr<session::KernelGraph> &kernel_graph)
}
}
auto fuse_nodes = FindFuseCNodes(node);
auto fuse_nodes = FindFuseCNodes(node->cast<CNodePtr>(), depend_prior);
if (fuse_nodes.size() <= 1) {
continue;
}
changed = true;
FuseNodesToSubGraph(fuse_nodes, kernel_graph, "");
AnfNodePtr fused_new_node;
AnfNodePtrList old_outputs;
std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "");
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs);
}
return changed;
}

View File

@ -16,11 +16,13 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
#include <limits>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include <memory>
#include <limits>
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/kernel_graph.h"
@ -29,7 +31,9 @@ namespace opt {
const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
"LambNextMV", "LambUpdateWithLR"};
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op, bool is_backward = true);
std::vector<AnfNodePtr> RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior,
bool is_backward = true);
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);

View File

@ -14,20 +14,24 @@
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include <map>
#include <set>
#include <tuple>
#include <unordered_set>
#include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/action.h"
#include <utility>
#include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "vm/segment_runner.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_decoder.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
#include "ir/func_graph_cloner.h"
#include "ir/func_graph.h"
#include "backend/optimizer/pass/const_input_to_attr_registry.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/action.h"
#include "vm/segment_runner.h"
#if ENABLE_GPU
#include "runtime/device/gpu/kernel_info_setter.h"
#endif
@ -526,12 +530,9 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f
}
}
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix) {
if (fuse_nodes.empty()) {
return;
}
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
const std::shared_ptr<session::KernelGraph> &kernel_graph,
const std::string &postfix) {
auto mng = kernel_graph->manager();
if (mng == nullptr) {
mng = Manage(kernel_graph, true);
@ -565,6 +566,8 @@ void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
}
fuse_op_name += postfix;
fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(fuse_op_name));
return std::make_tuple(fuse_new_node, src_outputs);
}
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
@ -737,7 +740,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape,
prim::kPrimTranspose, prim::kPrimCast};
prim::kPrimTranspose, prim::kPrimCast, prim::kPrimRealDiv};
#elif ENABLE_GPU
std::vector<PrimitivePtr> fusible_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
@ -786,5 +789,123 @@ void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
device::gpu::SetKernelInfo(cnode, kernel_type);
#endif
}
// Builds the control-dependency table consulted by the fusion passes.
//
// For every ControlDepend CNode in `todos`, the real kernels behind its prior input
// (kControlDependPriorIndex) and its behind/depend input (kControlDependBehindIndex) are gathered with
// AnfAlgo::GetAllFatherRealNode, and one record
//   depend_prior[depend] = pair(prior, control_depend_node)
// is inserted for every (prior, depend) combination.
//
// Parameter inputs are handled specially:
//   - prior is a Parameter and the node's kControlDependMode attribute is 1: the users of the parameter
//     (except this ControlDepend node) are taken as prior nodes;
//   - depend is a Parameter: with mode 1 its users (except this ControlDepend node) are taken as depend
//     nodes, otherwise no record is produced for this ControlDepend node.
// Real nodes that are themselves ControlDepend nodes are never recorded.
//
// todos:        nodes to scan; entries that are not ControlDepend CNodes are ignored.
// depend_prior: output table, must not be null; existing records are kept, new ones are appended.
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
                     std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior) {
  MS_EXCEPTION_IF_NULL(depend_prior);
  for (const auto &todo : todos) {
    auto cnode = todo->cast<CNodePtr>();
    if (cnode == nullptr) {
      continue;
    }
    if (!AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimControlDepend)) {
      continue;
    }

    auto prior_node = cnode->input(kControlDependPriorIndex);
    auto depend_node = cnode->input(kControlDependBehindIndex);
    MS_EXCEPTION_IF_NULL(prior_node);
    MS_EXCEPTION_IF_NULL(depend_node);
    std::vector<AnfNodePtr> prior_nodes = {prior_node};
    std::vector<AnfNodePtr> depend_nodes = {depend_node};

    // The attribute is stored as int64_t; keep that width instead of narrowing it to int.
    int64_t depend_mode = 0;
    if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) {
      depend_mode = AnfAlgo::GetNodeAttr<int64_t>(cnode, kControlDependMode);
    }

    // Returns every user of `param` other than the ControlDepend node being processed.
    auto GetOutputNodes = [&cnode](const AnfNodePtr &param) -> std::vector<AnfNodePtr> {
      std::vector<AnfNodePtr> out_nodes;
      auto param_graph = param->func_graph();
      MS_EXCEPTION_IF_NULL(param_graph);
      auto manager = param_graph->manager();
      MS_EXCEPTION_IF_NULL(manager);
      // Bind by reference: the user set only needs to be read, not copied.
      auto &user_set = manager->node_users()[param];
      for (auto user = user_set.cbegin(); user != user_set.cend(); ++user) {
        if (user->first != cnode) {
          out_nodes.push_back(user->first);
        }
      }
      return out_nodes;
    };

    if (prior_node->isa<Parameter>() && depend_mode == 1) {
      prior_nodes = GetOutputNodes(prior_node);
    }
    if (depend_node->isa<Parameter>()) {
      depend_nodes = depend_mode == 1 ? GetOutputNodes(depend_node) : std::vector<AnfNodePtr>{};
    }

    // Resolve virtual nodes (MakeTuple / TupleGetItem / Depend) down to the real kernels behind them.
    std::vector<AnfNodePtr> real_prior_nodes;
    std::set<AnfNodePtr> prior_visited;
    for (const auto &tmp : prior_nodes) {
      AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
    }
    std::vector<AnfNodePtr> real_depend_nodes;
    std::set<AnfNodePtr> depend_visited;
    for (const auto &tmp : depend_nodes) {
      AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
    }

    for (const auto &prior : real_prior_nodes) {
      if (AnfAlgo::CheckPrimitiveType(prior, prim::kPrimControlDepend)) {
        continue;
      }
      for (const auto &depend : real_depend_nodes) {
        if (AnfAlgo::CheckPrimitiveType(depend, prim::kPrimControlDepend)) {
          continue;
        }
        depend_prior->insert({depend, std::make_pair(prior, cnode)});
      }
    }
  }
}
void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend) {
for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
if (iter->second.second == control_depend_node) {
iter = depend_prior->erase(iter);
continue;
}
++iter;
}
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_depend_prior;
InitDependPrior(std::vector<AnfNodePtr>{new_control_depend}, &new_depend_prior);
for (auto item : new_depend_prior) {
depend_prior->insert(item);
}
}
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) {
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri;
for (size_t out_idx = 0; out_idx < outputs.size(); ++out_idx) {
if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimMakeTuple)) {
MS_LOG(ERROR) << "Need real outputs of makeTuple";
}
if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimTupleGetItem)) {
continue;
}
for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
if (iter->first == outputs[out_idx]) {
new_fuse_cnode_dep_pri.insert({new_fuse_cnode, iter->second});
iter = depend_prior->erase(iter);
continue;
}
if (iter->second.first == outputs[out_idx]) {
new_fuse_cnode_dep_pri.insert({iter->first, std::make_pair(new_fuse_cnode, iter->second.second)});
iter = depend_prior->erase(iter);
continue;
}
++iter;
}
}
for (auto item : new_fuse_cnode_dep_pri) {
depend_prior->insert(item);
}
}
} // namespace opt
} // namespace mindspore

View File

@ -15,17 +15,20 @@
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_
#include <string>
#include <vector>
#include <memory>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_set>
#include <nlohmann/json.hpp>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "backend/session/kernel_graph.h"
#include "backend/kernel_compiler/akg/akg_kernel_json_generator.h"
#include <nlohmann/json.hpp>
namespace mindspore {
namespace opt {
@ -48,8 +51,9 @@ AnfNodePtr CreateNewFuseCNode(const FuncGraphPtr &kernel_graph, const FuncGraphP
const AnfNodePtrList &outputs);
void ReplaceNewFuseCNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &new_fuse_cnode,
const AnfNodePtrList &outputs);
void FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &postfix);
std::tuple<AnfNodePtr, AnfNodePtrList> FuseNodesToSubGraph(const std::vector<AnfNodePtr> &fuse_nodes,
const std::shared_ptr<session::KernelGraph> &kernel_graph,
const std::string &postfix);
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc);
bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, nlohmann::json *op_desc,
std::map<std::string, AnfNodePtr> *address_node_map);
@ -60,6 +64,12 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
std::vector<PrimitivePtr> GetFusibleOpList();
bool IsBasicFuseOp(const AnfNodePtr &node);
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
// Scans todos for ControlDepend nodes and inserts, for each pair of real prior/depend nodes found behind
// such a node, a record depend_prior[depend] = pair(prior, ControlDepend node).
void InitDependPrior(const std::vector<AnfNodePtr> &todos,
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior);
// Removes the records created by control_depend_node from depend_prior and inserts the records computed
// (via InitDependPrior) from new_control_depend.
void UpdateControlDependNode(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &control_depend_node, const AnfNodePtr &new_control_depend);
// Rewrites the records of depend_prior whose depend or prior node is one of outputs (TupleGetItem
// outputs are skipped) so that they refer to new_fuse_cnode instead.
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_HELPER_H_

View File

@ -1417,5 +1417,46 @@ std::vector<size_t> AnfRuntimeAlgorithm::GetOutputRealDeviceShapeIfExist(const A
}
return device_shape;
}
// Collects the real kernel nodes that stand behind `anf_node`.
//
// If `anf_node` is a real kernel (AnfAlgo::IsRealKernel) it is appended to `result`. Otherwise, when it
// is a CNode, the search continues recursively through:
//   - MakeTuple:    every input after the primitive;
//   - TupleGetItem: the input at kRealInputNodeIndexInTupleGetItem;
//   - Depend:       the inputs at kRealInputIndexInDepend and kDependAttachNodeIndex.
// Any other node contributes nothing.
//
// anf_node: node to start from, must not be null.
// result:   output list the real nodes are appended to, must not be null.
// visited:  nodes already handled, must not be null; a node found in it is skipped, which keeps shared
//           inputs from being processed twice.
// Raises (MS_LOG(EXCEPTION)) when a CNode has no inputs, or when a TupleGetItem / Depend node does not
// have the expected number of inputs.
void AnfRuntimeAlgorithm::GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
                                               std::set<AnfNodePtr> *visited) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(result);
  MS_EXCEPTION_IF_NULL(visited);
  // insert() reports whether the node was new, so one lookup both tests and marks it.
  if (!visited->insert(anf_node).second) {
    MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has already been visited";
    return;
  }
  if (AnfAlgo::IsRealKernel(anf_node)) {
    result->emplace_back(anf_node);
    return;
  }
  if (!anf_node->isa<CNode>()) {
    return;
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().empty()) {
    // The message is streamed, so the node description is appended rather than printf-formatted.
    MS_LOG(EXCEPTION) << "Illegal null input of cnode(" << anf_node->DebugString() << ")";
  }
  auto input0 = cnode->input(0);
  if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
      GetAllFatherRealNode(cnode->input(i), result, visited);
    }
  } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
    if (cnode->inputs().size() != kTupleGetItemInputSize) {
      MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
    }
    GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
  } else if (IsPrimitive(input0, prim::kPrimDepend)) {
    if (cnode->inputs().size() != kDependInputSize) {
      MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
    }
    GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
    GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
  }
}
} // namespace session
} // namespace mindspore

View File

@ -232,6 +232,9 @@ class AnfRuntimeAlgorithm {
static bool IsNodeDynamicShape(const AnfNodePtr &node);
static std::vector<size_t> GetInputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
static std::vector<size_t> GetOutputRealDeviceShapeIfExist(const AnfNodePtr &anf_node, size_t index);
// Find control_depend real input nodes.
static void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result,
std::set<AnfNodePtr> *visited);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;

View File

@ -725,47 +725,6 @@ std::vector<AnfNodePtr> KernelGraph::GetOutputNodes(const AnfNodePtr &node) {
return output_nodes;
}
// Find control_depend real input nodes.
//
// Appends to `result` the real kernel nodes (AnfAlgo::IsRealKernel) reachable from `anf_node` by looking
// through MakeTuple (all inputs after the primitive), TupleGetItem (the input at
// kRealInputNodeIndexInTupleGetItem) and Depend (the inputs at kRealInputIndexInDepend and
// kDependAttachNodeIndex) nodes. `visited` records the nodes already handled so that each node is
// processed once. All three arguments must be non-null.
// Raises (MS_LOG(EXCEPTION)) when a CNode has no inputs, or when a TupleGetItem / Depend node does not
// have the expected number of inputs.
void GetAllFatherRealNode(const AnfNodePtr &anf_node, std::vector<AnfNodePtr> *result, std::set<AnfNodePtr> *visited) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(result);
  MS_EXCEPTION_IF_NULL(visited);
  if (visited->find(anf_node) != visited->end()) {
    MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " has already been visited";
    return;
  }
  visited->insert(anf_node);
  if (AnfAlgo::IsRealKernel(anf_node)) {
    result->emplace_back(anf_node);
    return;
  }
  if (!anf_node->isa<CNode>()) {
    return;
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().empty()) {
    // Stream the node description; a printf-style placeholder would be printed literally.
    MS_LOG(EXCEPTION) << "Illegal null input of cnode(" << anf_node->DebugString() << ")";
  }
  auto input0 = cnode->input(0);
  if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
      GetAllFatherRealNode(cnode->input(i), result, visited);
    }
  } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
    if (cnode->inputs().size() != kTupleGetItemInputSize) {
      MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
    }
    GetAllFatherRealNode(cnode->input(kRealInputNodeIndexInTupleGetItem), result, visited);
  } else if (IsPrimitive(input0, prim::kPrimDepend)) {
    if (cnode->inputs().size() != kDependInputSize) {
      MS_LOG(EXCEPTION) << "Depend node must have 2 inputs!";
    }
    GetAllFatherRealNode(cnode->input(kRealInputIndexInDepend), result, visited);
    GetAllFatherRealNode(cnode->input(kDependAttachNodeIndex), result, visited);
  }
}
// update the depend relations of control depend
void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &depends) {
for (const auto &node : depends) {
@ -800,12 +759,12 @@ void KernelGraph::UpdateControlDependRelations(const std::vector<AnfNodePtr> &de
std::vector<AnfNodePtr> real_prior_nodes;
std::set<AnfNodePtr> prior_visited;
for (const auto &tmp : prior_nodes) {
GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited);
}
std::vector<AnfNodePtr> real_depend_nodes;
std::set<AnfNodePtr> depend_visited;
for (const auto &tmp : depend_nodes) {
GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
AnfAlgo::GetAllFatherRealNode(tmp, &real_depend_nodes, &depend_visited);
}
UpdateNodeInputOutputEdges(real_prior_nodes, real_depend_nodes);
}