diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index e13b34f9d8f..dec0059d4a6 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -1269,15 +1269,17 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
   } else {
     AnfNodeIndexSet param_sub_set = manager->node_users()[parameter];
     for (auto &param_pair : param_sub_set) {
-      CNodePtr graph_cnode = param_pair.first->cast<CNodePtr>();
-      if ((graph_cnode == nullptr) || !graph_cnode->input(0)->isa<CNode>()) {
+      CNodePtr param_cnode = param_pair.first->cast<CNodePtr>();
+      AnfNodePtr graph_value_node;
+      if (param_cnode->input(0)->isa<CNode>()) {
+        graph_value_node = param_cnode->input(0)->cast<CNodePtr>()->input(1);
+      } else {
+        graph_value_node = param_cnode->input(0);
+      }
+      if (!IsValueNode<FuncGraph>(graph_value_node)) {
         continue;
       }
-      CNodePtr graph_cnode_inp0 = graph_cnode->input(0)->cast<CNodePtr>();
-      if (!IsValueNode<FuncGraph>(graph_cnode_inp0->input(1))) {
-        continue;
-      }
-      FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_cnode_inp0->input(1));
+      FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_value_node);
       auto parameters = graph_sub->parameters();
       if (IntToSize(param_pair.second - 1) >= parameters.size()) {
         MS_LOG(EXCEPTION) << "The index is out of range, index is " << param_pair.second - 1 << ", vector size is "
@@ -1864,7 +1866,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
 
   // return -> make_tuple
   if (current_prim->name() == MAKE_TUPLE) {
-    MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported";
+    MS_LOG(WARNING) << "The loss have make_tuple, it is not supported";
+    return nullptr;
   }
 
   // return -> loss
@@ -2069,6 +2072,12 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
       auto graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
       MS_LOG(DEBUG) << "Find the forward graph success";
       graph_set.insert(graph);
+      auto manager = graph->manager();
+      MS_EXCEPTION_IF_NULL(manager);
+      auto graph_used = manager->func_graphs_used_total(graph);
+      for (auto &sub_graph : graph_used) {
+        graph_set.insert(sub_graph);
+      }
     }
   }
   return graph_set;
@@ -2423,7 +2432,7 @@ void HandleRootReshape(const std::vector<AnfNodePtr> &all_nodes) {
 void MarkForwardCNode(const FuncGraphPtr &root) {
   MS_EXCEPTION_IF_NULL(root);
   auto all_nodes = root->nodes();
-  std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
+  auto graph_set = FindForwardGraphByRootNodes(all_nodes);
 
   if (graph_set.empty()) {
     MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph";
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h
index a9a4d941b25..804d470afad 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.h
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h
@@ -145,10 +145,10 @@ int32_t GetTupleGetItemIndex(const CNodePtr &cnode);
 
 Status ParallelInit();
 
-std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
-
 std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
 
+std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
+
 bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);
 
 using RefKeyPair = std::pair<AnfNodePtr, std::vector<AnfNodePtr>>;
diff --git a/tests/ut/python/parallel/test_auto_parallel_multi_graph.py b/tests/ut/python/parallel/test_auto_parallel_multi_graph.py
new file mode 100644
index 00000000000..f510fdedebf
--- /dev/null
+++ b/tests/ut/python/parallel/test_auto_parallel_multi_graph.py
@@ -0,0 +1,69 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import mindspore as ms
+import mindspore.context as context
+from mindspore.common.api import _executor
+from mindspore import Tensor, Parameter
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+
+
+class TwoInputBprop(Cell):
+    def __init__(self):
+        super().__init__()
+        self.op = P.Mul()
+
+    def construct(self, x, y):
+        return self.op(x, y)
+
+    def bprop(self, x, y, out, dout):
+        return x * 5, y * 8
+
+
+class ParallelFloorDivBpropNet(Cell):
+    def __init__(self, mul_size, test_size, strategy=None, strategy2=None):
+        super().__init__()
+        mul_np = np.full(mul_size, 0.5, dtype=np.float32)
+        floordiv_np = np.full(test_size, 0.1, dtype=np.float32)
+        self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
+        self.floordiv_weight = Parameter(Tensor(floordiv_np), name="floordiv_weight")
+        self.mul = TwoInputBprop()
+        self.floor_div = P.FloorDiv()
+        if strategy is not None:
+            self.mul.op.shard(strategy2)
+            self.floor_div.shard(strategy)
+
+    def construct(self, inputs, label):
+        x = self.mul(inputs, self.mul_weight)
+        x = self.floor_div(x, self.floordiv_weight)
+        return x
+
+
+inputs_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
+label_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
+def compile_net(net):
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    _executor.compile(train_net, inputs_, label_)
+    context.reset_auto_parallel_context()
+
+
+def test_net():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
+    strategy = ((4, 1), (4, 1))
+    net = ParallelFloorDivBpropNet(mul_size=(128, 96), test_size=(128, 96), strategy=strategy, strategy2=strategy)
+    compile_net(net)