From 4c34867486b6e56435ffdc63e44531bf579542e9 Mon Sep 17 00:00:00 2001
From: huanghui
Date: Thu, 16 Jul 2020 16:59:11 +0800
Subject: [PATCH] control optimize for heterogeneous executor

---
 .../ccsrc/backend/session/session_basic.cc   |  6 ++
 mindspore/ccsrc/vm/segment_runner.cc         |  8 ++
 mindspore/ccsrc/vm/transform.cc              | 96 ++++++++++++++++++-
 mindspore/core/ir/anf.cc                     | 10 ++
 .../st/heterogeneous_excutor/test_control.py | 71 ++++++++++++++
 5 files changed, 189 insertions(+), 2 deletions(-)
 create mode 100644 tests/st/heterogeneous_excutor/test_control.py

diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index 91939d493a8..8b8615979ea 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -451,10 +451,14 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
   }
   auto origin_inputs = cnode->inputs();
   bool optimize_depend = false;
+  bool optimize_control_depend = false;
   if (IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() == 3 &&
       origin_inputs[kRealInputIndexInDepend]->isa<ValueNode>()) {
     optimize_depend = true;
   }
+  if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3) {
+    optimize_control_depend = true;
+  }
   // if has multiple depends,only select first depend as parameter
   for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
     auto anf = origin_inputs[input_idx];
@@ -485,6 +489,8 @@
     } else if (optimize_depend && input_idx == kDependAttachNodeIndex) {
       cnode_inputs.push_back(origin_inputs[kRealInputIndexInDepend]);
       continue;
+    } else if (optimize_control_depend) {
+      cnode_inputs.push_back(NewValueNode(MakeValue(input_idx)));
     } else {
       *from_other_graph = true;
       // the input node is a cnode from other graph
diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc
index 540b77bcaf1..151c20a5355 100644
--- a/mindspore/ccsrc/vm/segment_runner.cc
+++ b/mindspore/ccsrc/vm/segment_runner.cc
@@ -117,6 +117,14 @@ std::tuple<FuncGraphPtr, AnfNodePtrList, AnfNodePtrList> TransformSegmentToAnfGr
         eqv.find(inps[kDependAttachNodeIndex]) == eqv.end()) {
       args.emplace_back(inps[kRealInputIndexInDepend]);
       args.emplace_back(inps[kRealInputIndexInDepend]);
+    } else if (IsPrimitive(fn, prim::kPrimControlDepend) && inps.size() == 3) {
+      for (size_t i = 1; i < inps.size(); ++i) {
+        if (inps[i]->isa<CNode>() && std::find(lst.begin(), lst.end(), inps[i]) == lst.end()) {
+          args.emplace_back(NewValueNode(MakeValue(i)));
+        } else {
+          args.emplace_back(ref(inps[i]));
+        }
+      }
     } else {
       (void)std::transform(std::begin(inps) + 1, std::end(inps), std::back_inserter(args), ref);
     }
diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc
index 2cf6ead8130..0b96f2feb90 100644
--- a/mindspore/ccsrc/vm/transform.cc
+++ b/mindspore/ccsrc/vm/transform.cc
@@ -69,7 +69,91 @@ bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
   return false;
 }
 
-void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref) {
+bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node,
+                  std::vector<AnfNodePtr> *prior_nodes, std::vector<AnfNodePtr> *depend_nodes) {
+  MS_EXCEPTION_IF_NULL(prior_node);
+  MS_EXCEPTION_IF_NULL(behind_node);
+  MS_EXCEPTION_IF_NULL(graph);
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto &node_users = manager->node_users();
+  if (prior_node->isa<Parameter>()) {
+    for (auto &user : node_users[prior_node]) {
+      auto cnode = user.first->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(cnode);
+      if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
+        prior_nodes->emplace_back(cnode);
+      }
+    }
+  } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) {
+    prior_nodes->emplace_back(prior_node);
+  } else {
+    return false;
+  }
+  if (behind_node->isa<Parameter>()) {
+    for (auto &user : node_users[behind_node]) {
+      auto cnode = user.first->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(cnode);
+      if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) {
+        depend_nodes->emplace_back(cnode);
+      }
+    }
+  } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) {
+    depend_nodes->emplace_back(behind_node);
+  } else {
+    return false;
+  }
+  return true;
+}
+
+void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node,
+                    std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges,
+                    std::map<AnfNodePtr, size_t> *nodes_ref) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto input_cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(input_cnode);
+  auto prior_node = input_cnode->input(kControlDependPriorIndex);
+  auto depend_node = input_cnode->input(kControlDependBehindIndex);
+  MS_EXCEPTION_IF_NULL(prior_node);
+  MS_EXCEPTION_IF_NULL(depend_node);
+  PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0));
+  MS_EXCEPTION_IF_NULL(prim_ptr);
+  ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode");
+  int depend_mode = 0;
+  if (mode_ptr != nullptr) {
+    depend_mode = GetValue<int>(mode_ptr);
+  }
+  if ((prior_node->isa<Parameter>() || depend_node->isa<Parameter>()) && depend_mode == 0) {
+    return;
+  }
+  std::vector<AnfNodePtr> prior_nodes;
+  std::vector<AnfNodePtr> behind_nodes;
+  if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) {
+    return;
+  }
+  for (auto &first_node : prior_nodes) {
+    for (auto &second_node : behind_nodes) {
+      MS_EXCEPTION_IF_NULL(first_node);
+      MS_EXCEPTION_IF_NULL(second_node);
+      auto iter = control_edges->find(second_node);
+      if (iter == control_edges->end()) {
+        (void)control_edges->insert(
+          std::pair<AnfNodePtr, std::vector<AnfNodePtr>>(second_node, std::vector<AnfNodePtr>{first_node}));
+      } else {
+        iter->second.emplace_back(first_node);
+      }
+      auto ref_iter = nodes_ref->find(first_node);
+      if (ref_iter != nodes_ref->end()) {
+        ref_iter->second++;
+      } else {
+        (void)nodes_ref->insert(std::pair<AnfNodePtr, size_t>(first_node, 1));
+      }
+    }
+  }
+}
+
+void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *nodes_ref,
+                      std::map<AnfNodePtr, std::vector<AnfNodePtr>> *control_edges) {
   std::queue<AnfNodePtr> queue;
   queue.push(graph->get_return());
   std::set<AnfNodePtr> visited;
@@ -83,6 +167,9 @@ void CalcNodeRefCount(const FuncGraphPtr &graph, std::map<AnfNodePtr, size_t> *n
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     for (auto &input : cnode->inputs()) {
+      if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) {
+        AddControlEdge(graph, input, control_edges, nodes_ref);
+      }
       auto iter = nodes_ref->find(input);
       if (iter != nodes_ref->end()) {
         iter->second++;
@@ -142,7 +229,8 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
   std::stack<AnfNodePtr> to_visit;
   std::stack<AnfNodePtr> next_to_visit;
   std::map<AnfNodePtr, size_t> nodes_ref;
-  CalcNodeRefCount(graph, &nodes_ref);
+  std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges;
+  CalcNodeRefCount(graph, &nodes_ref, &control_edges);
   std::string handle_target = default_target;
   std::string next_target = "";
   to_visit.push(graph->get_return());
@@ -162,6 +250,10 @@
     MS_EXCEPTION_IF_NULL(cnode);
     auto node_inputs = cnode->inputs();
     std::reverse(node_inputs.begin(), node_inputs.end());
+    auto ctrl_inputs = control_edges.find(node);
+    if (ctrl_inputs != control_edges.end()) {
+      node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end());
+    }
     for (auto &input : node_inputs) {
       auto iter = nodes_ref.find(input);
       if (iter != nodes_ref.end()) {
diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc
index e238012b141..10ccc31050b 100644
--- a/mindspore/core/ir/anf.cc
+++ b/mindspore/core/ir/anf.cc
@@ -26,6 +26,7 @@
 #include "ir/func_graph.h"
 #include "ir/primitive.h"
 #include "utils/context/ms_context.h"
+#include "base/core_ops.h"
 
 namespace mindspore {
 // namespace to support intermediate representation definition
@@ -217,6 +218,15 @@ std::string GetCNodeTarget(const AnfNodePtr &node) {
   auto primitive = value->cast<PrimitivePtr>();
   auto att_target = primitive->GetAttr("primitive_target");
   if (att_target != nullptr) {
+    if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) ||
+        IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) ||
+        IsPrimitive(attr_input, prim::kPrimMakeTuple) || IsPrimitive(attr_input, prim::kPrimStateSetItem) ||
+        IsPrimitive(attr_input, prim::kPrimDepend) || IsPrimitive(attr_input, prim::kPrimTupleGetItem) ||
+        IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) ||
+        IsPrimitive(attr_input, prim::kPrimPartial)) {
+      primitive->EraseAttr("primitive_target");
+      return default_target;
+    }
     if (!att_target->isa<StringImm>()) {
       MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
     }
diff --git a/tests/st/heterogeneous_excutor/test_control.py b/tests/st/heterogeneous_excutor/test_control.py
new file mode 100644
index 00000000000..189441f1f99
--- /dev/null
+++ b/tests/st/heterogeneous_excutor/test_control.py
@@ -0,0 +1,71 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net1(nn.Cell):
+    def __init__(self):
+        super(Net1, self).__init__()
+        self.relu1 = P.ReLU()
+        self.relu2 = P.ReLU()
+        self.mul = P.Mul()
+        self.control = P.ControlDepend()
+
+    def construct(self, x, y):
+        a = self.relu1(x)
+        b = self.relu2(y)
+        c = self.mul(a, b)
+        e = self.control(a, b)
+        return c, e
+
+
+class Net2(nn.Cell):
+    def __init__(self):
+        super(Net2, self).__init__()
+        self.relu1 = P.ReLU()
+        self.relu2 = P.ReLU().add_prim_attr("primitive_target", "CPU")
+        self.mul = P.Mul()
+        self.control = P.ControlDepend()
+
+    def construct(self, x, y):
+        a = self.relu1(x)
+        b = self.relu2(y)
+        c = self.mul(a, b)
+        e = self.control(a, b)
+        return c, e
+
+
+def test_net():
+    x = np.random.randn(2, 3, 3, 4).astype(np.float32)
+    y = np.random.randn(2, 3, 3, 4).astype(np.float32)
+    net1 = Net1()
+    output1 = net1(Tensor(x), Tensor(y))
+
+    context.set_context(save_graphs=True)
+    net2 = Net2()
+    output2 = net2(Tensor(x), Tensor(y))
+    assert np.allclose(output1[0].asnumpy(), output2[0].asnumpy())
+    print("##success##")
+
+
+if __name__ == "__main__":
+    test_net()
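
Note (illustrative sketch, not part of the patch): the AddControlEdge change above only records a control edge for a ControlDepend whose prior or behind node is a Parameter when depend_mode is 1; with the default depend_mode of 0 it returns early and the edge is dropped. The snippet below sketches that case, assuming the ControlDepend primitive of this MindSpore version accepts a depend_mode keyword (the attribute AddControlEdge reads). Net3 and run_net3 are hypothetical names used only for illustration.

import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Net3(nn.Cell):
    def __init__(self):
        super(Net3, self).__init__()
        # relu is pinned to CPU, mul stays on the default Ascend target
        self.relu = P.ReLU().add_prim_attr("primitive_target", "CPU")
        self.mul = P.Mul()
        # Assumption: depend_mode=1 is passed as a keyword; with a Parameter as the
        # prior node, ExtractNodes expands it to the users of that Parameter.
        self.control = P.ControlDepend(depend_mode=1)

    def construct(self, x, y):
        a = self.relu(x)   # CPU branch, the only user of x
        b = self.mul(y, y)  # Ascend branch, no data dependency on a
        # ControlDepend(x, b): x is a Parameter, so with depend_mode=1 the control
        # edge is attached from x's users (the CPU relu) to b, ordering relu before mul.
        e = self.control(x, b)
        return a, b, e


def run_net3():
    x = Tensor(np.random.randn(2, 3, 3, 4).astype(np.float32))
    y = Tensor(np.random.randn(2, 3, 3, 4).astype(np.float32))
    out = Net3()(x, y)
    print(out[1].asnumpy().shape)

With depend_mode left at its default of 0, the same ControlDepend(x, b) would be skipped by AddControlEdge, because its prior node is a Parameter.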