From d4bba3f1d200cb70921595b58cb76e0c7ea83b1b Mon Sep 17 00:00:00 2001
From: lichenever
Date: Fri, 18 Sep 2020 19:22:59 +0800
Subject: [PATCH] fix_auto_parallel_find_loss_bug

---
 .../ccsrc/frontend/parallel/step_parallel.cc  | 121 ++++++------------
 .../ccsrc/frontend/parallel/step_parallel.h   |   3 +-
 mindspore/train/model.py                      |   3 -
 tests/ut/python/parallel/test_mul_div_bn.py   |  78 +++++++++++
 4 files changed, 119 insertions(+), 86 deletions(-)
 create mode 100644 tests/ut/python/parallel/test_mul_div_bn.py

diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index ea09e4786e..a3e80d31e8 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -711,61 +711,6 @@ int32_t GetTupleGetItemIndex(const CNodePtr &cnode) {
   return tuple_index_value->cast<Int32ImmPtr>()->value();
 }
 
-// Judge whether the node is a loss, and if there are multiple outputs,
-// get which output is a grad according to the tuple getitem.
-// Currently, it is not supported that the sens is a tuple.
-LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) {
-  MS_EXCEPTION_IF_NULL(loss_node);
-  FuncGraphPtr sub_graph = loss_node->func_graph();
-  MS_EXCEPTION_IF_NULL(sub_graph);
-  CNodePtr return_node = sub_graph->get_return();
-  MS_EXCEPTION_IF_NULL(return_node);
-  if (return_node->inputs().size() < 2) {
-    MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2";
-  }
-  AnfNodePtr pre_node = return_node->input(1);
-  MS_EXCEPTION_IF_NULL(pre_node);
-
-  LossNodeInfo node_info;
-
-  // return -> cast
-  auto pre_cnode = pre_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(pre_cnode);
-  auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
-  if (pre_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
-    pre_node = pre_cnode->input(1);
-  }
-
-  // return -> loss
-  if (pre_node == loss_node) {
-    node_info.has_tuple_getitem = false;
-    node_info.dout_index = 0;
-    return node_info;
-  }
-
-  // return -> tuple_getitem -> loss
-  auto cnode = pre_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto current_value = cnode->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(current_value);
-  PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
-  MS_EXCEPTION_IF_NULL(current_prim);
-  // size of common cnode is larger than 1
-  if (cnode->inputs().size() < 2) {
-    MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is smaller than 2";
-  }
-
-  if ((current_prim->name() == TUPLE_GETITEM) && (cnode->input(1) == loss_node)) {
-    // size of tuple_getitem cnode is 3
-    auto tuple_index = GetTupleGetItemIndex(cnode);
-    node_info.has_tuple_getitem = true;
-    node_info.dout_index = tuple_index;
-    return node_info;
-  }
-
-  MS_LOG(EXCEPTION) << "Invalid loss";
-}
-
 void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   size_t node_size = node->inputs().size();
@@ -958,13 +903,13 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
 }
 
 void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
-                           const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs) {
+                           const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
   MS_EXCEPTION_IF_NULL(distribute_operator);
   MS_EXCEPTION_IF_NULL(node);
   bool is_loss_cnode =
     std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
-                [node](const std::pair<CNodePtr, CNodePtr> &element) { return element.second == node; });
+                [node](const std::pair<CNodePtr, LossNodeInfo> &element) { return element.second.loss_node == node; });
   MirrorOps mirror_ops = distribute_operator->mirror_ops();
   VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
@@ -1819,7 +1764,20 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
   }
 }
 
-CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
+CNodePtr HandleDependLoss(const CNodePtr &cnode) {
+  // Handle return->depend->loss
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  MS_EXCEPTION_IF_NULL(prim);
+  if (prim->name() == DEPEND) {
+    auto depend_before = cnode->input(1)->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(depend_before);
+    return HandleDependLoss(depend_before);
+  }
+  return cnode;
+}
+
+LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
+  LossNodeInfo loss_node_info;
   MS_EXCEPTION_IF_NULL(func_graph);
   CNodePtr return_node = func_graph->get_return();
   MS_EXCEPTION_IF_NULL(return_node);
@@ -1831,9 +1789,9 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
 
   auto pre_cnode = pre_node->cast<CNodePtr>();
   if (pre_cnode == nullptr) {
-    return nullptr;
+    return loss_node_info;
   }
-
+  pre_cnode = HandleDependLoss(pre_cnode);
   auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
   // return -> cast
   if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
@@ -1845,7 +1803,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
   // notice: the GetNext op has not input
   if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
     MS_LOG(INFO) << "The loss is: " << current_prim->name();
-    return pre_cnode;
+    loss_node_info.loss_node = pre_cnode;
+    return loss_node_info;
   }
 
   // size of common cnode is larger than 1
@@ -1855,36 +1814,34 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
 
   // return -> tuple_getitem -> loss
   if (current_prim->name() == TUPLE_GETITEM) {
+    auto tuple_index = GetTupleGetItemIndex(pre_cnode);
     AnfNodePtr pre_pre_node = pre_cnode->input(1);
     MS_EXCEPTION_IF_NULL(pre_pre_node);
 
     auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
-    auto value = pre_pre_cnode->input(0)->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(value);
-    PrimitivePtr prim = value->value()->cast<PrimitivePtr>();
-    MS_EXCEPTION_IF_NULL(prim);
-    MS_LOG(DEBUG) << "The loss name is " << prim->name();
-    return pre_pre_cnode;
+    loss_node_info.has_tuple_getitem = true;
+    loss_node_info.dout_index = tuple_index;
+    loss_node_info.loss_node = pre_pre_cnode;
+    return loss_node_info;
   }
 
   // return -> make_tuple
   if (current_prim->name() == MAKE_TUPLE) {
     MS_LOG(WARNING) << "The loss have make_tuple, it is not supported";
-    return nullptr;
+    return loss_node_info;
   }
 
   // return -> loss
+  loss_node_info.loss_node = pre_cnode;
   MS_LOG(DEBUG) << "The loss name is " << current_prim->name();
-  return pre_cnode;
+  return loss_node_info;
 }
 
-TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
+TensorLayouts GetLossNodeGradOutputLayout(const LossNodeInfo &node_info) {
   TensorLayouts ret;
+  auto loss_cnode = node_info.loss_node;
   MS_EXCEPTION_IF_NULL(loss_cnode);
-  AnfNodePtr node = loss_cnode->cast<AnfNodePtr>();
-  MS_EXCEPTION_IF_NULL(node);
-  LossNodeInfo node_info = GetLossNodeInfo(node);
 
   ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast<ValueNodePtr>();
   MS_EXCEPTION_IF_NULL(prim_anf_node);
   PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
@@ -2086,9 +2043,9 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) {
   return graph_set;
 }
 
-void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
+void StepSplitSens(const std::pair<CNodePtr, LossNodeInfo> &sens_loss_pair) {
   CNodePtr sens_node = sens_loss_pair.first;
-  CNodePtr loss_node = sens_loss_pair.second;
+  auto loss_node = sens_loss_pair.second;
   auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node);
   if (!loss_grad_layout.empty()) {
     SplitSens(sens_node, loss_grad_layout[0]);
@@ -2096,9 +2053,9 @@ void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
 }
 
 // Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
-std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &root) {
+std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphPtr &root) {
   MS_EXCEPTION_IF_NULL(root);
-  std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs;
+  std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs;
   for (auto &node : root->nodes()) {
     if (!node->isa<CNode>()) {
       continue;
     }
@@ -2140,12 +2097,12 @@ std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &root) {
       MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
     }
     auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
-    auto loss_cnode = FindLossCNode(func_graph);
-    if (loss_cnode == nullptr) {
+    auto loss_node_info = FindLossCNode(func_graph);
+    if (loss_node_info.loss_node == nullptr) {
       MS_LOG(WARNING) << "Can not find the loss cnode";
       continue;
     }
-    std::pair<CNodePtr, CNodePtr> sens_loss_pair = std::make_pair(sens_cnode, loss_cnode);
+    std::pair<CNodePtr, LossNodeInfo> sens_loss_pair = std::make_pair(sens_cnode, loss_node_info);
     sens_loss_pairs.push_back(sens_loss_pair);
   }
   return sens_loss_pairs;
@@ -2157,7 +2114,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
-  std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs = GetSensLossPairs(root);
+  std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs = GetSensLossPairs(root);
   bool has_backward = !sens_loss_pairs.empty();
   // split sens must before inserting the operators.
   for (auto &pair : sens_loss_pairs) {
@@ -2372,7 +2329,7 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
 std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) {
   MS_EXCEPTION_IF_NULL(graph);
   std::vector<AnfNodePtr> root_forward_nodes;
-  auto loss_cnode = FindLossCNode(graph);
+  auto loss_cnode = FindLossCNode(graph).loss_node;
   if (loss_cnode == nullptr) {
     MS_LOG(WARNING) << "Can not find the loss cnode";
     return root_forward_nodes;
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h
index 804d470afa..e36dea6ab1 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.h
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h
@@ -39,6 +39,7 @@ const uint64_t kUSecondInSecond = 1000000;
 struct LossNodeInfo {
   bool has_tuple_getitem = false;
   int dout_index = 0;  // now don't support the sens is a tuple
+  CNodePtr loss_node = nullptr;
 };
 
 std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name);
@@ -82,7 +83,7 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
 void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node);
 
 void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
-                           const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs);
+                           const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs);
 
 // Generate and init parallel operator
 OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 6c1eaa64bd..e984743f2a 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -32,7 +32,6 @@ from ..nn.metrics import Loss
 from .. import nn
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from ..context import ParallelMode
-from ..parallel._utils import _need_to_full, _to_full_tensor
 from ..parallel._cost_model_context import _set_multi_subgraphs
 from .dataset_helper import DatasetHelper, connect_network_with_dataset
 from . import amp
@@ -436,8 +435,6 @@ class Model:
 
             # for data sink dataset_helper only iter once, other wise iter epoch_size times.
             for inputs in dataset_helper:
-                if _need_to_full() and context.get_context("device_target") == "GPU":
-                    inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
                 cb_params.train_dataset_element = inputs
                 list_callback.step_begin(run_context)
                 outputs = self._train_network(*inputs)
diff --git a/tests/ut/python/parallel/test_mul_div_bn.py b/tests/ut/python/parallel/test_mul_div_bn.py
new file mode 100644
index 0000000000..9254ae9a18
--- /dev/null
+++ b/tests/ut/python/parallel/test_mul_div_bn.py
@@ -0,0 +1,78 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import mindspore as ms
+import mindspore.context as context
+from mindspore.common.api import _executor
+from mindspore import Tensor, Parameter
+import mindspore.nn as nn
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+
+
+class TwoInputBpropOperator(Cell):
+    def __init__(self):
+        super().__init__()
+        self.op = P.Mul()
+        self.bp = P.TensorAdd()
+
+    def construct(self, x, y):
+        return self.op(x, y)
+
+    def bprop(self, x, y, out, dout):
+        return self.bp(5, x), self.bp(y, 8)
+
+
+class ParallelFloorDivBpropNet(Cell):
+    def __init__(self, mul_size, test_size, strategy=None, strategy2=None):
+        super().__init__()
+        mul_np = np.full(mul_size, 0.5, dtype=np.float32)
+        floordiv_np = np.full(test_size, 0.1, dtype=np.float32)
+        self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
+        self.floordiv_weight = Parameter(Tensor(floordiv_np), name="floordiv_weight")
+        self.mul = TwoInputBpropOperator()
+        self.floor_div = P.FloorDiv()
+        self.bn = nn.BatchNorm1d(num_features=96)
+        if strategy is not None:
+            self.mul.op.shard(strategy2)
+            self.mul.bp.shard(strategy2)
+            self.floor_div.shard(strategy)
+
+    def construct(self, inputs, label):
+        x = self.mul(inputs, self.mul_weight)
+        x = self.floor_div(x, self.floordiv_weight)
+        x = self.bn(x)
+        return x
+
+
+inputs_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
+label_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
+
+
+def compile_net(net):
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    train_net.set_train()
+    _executor.compile(train_net, inputs_, label_)
+    context.reset_auto_parallel_context()
+
+
+def test_net():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
+    strategy = ((4, 1), (4, 1))
+    net = ParallelFloorDivBpropNet(mul_size=(128, 96), test_size=(128, 96), strategy=strategy, strategy2=strategy)
+    compile_net(net)
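
Note (not part of the patch): the core change above teaches FindLossCNode to
look through Depend nodes sitting between the graph's return and the real loss
(HandleDependLoss recurses on the Depend's first input), and to carry the
result in a LossNodeInfo (loss node plus dout_index) instead of a bare
CNodePtr. Below is a minimal sketch of a cell that produces the
return -> depend -> loss pattern, assuming MindSpore APIs of this era
(F.depend, P.Assign); the cell name, shapes, and weight are hypothetical,
not taken from the patch:

    import numpy as np
    from mindspore import Tensor, Parameter, nn
    from mindspore.ops import operations as P
    from mindspore.ops import functional as F


    class DependLossNet(nn.Cell):
        """Hypothetical cell whose graph output is wrapped in a Depend node,
        i.e. the return -> depend -> loss shape that HandleDependLoss unwraps."""

        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.assign = P.Assign()
            self.weight = Parameter(Tensor(np.ones([128, 96], np.float32)), name="w")

        def construct(self, x):
            loss = self.mul(x, self.weight)
            # Tying a side-effecting update to the output means the value that
            # reaches the graph's Return node is a Depend CNode, not the loss.
            update = self.assign(self.weight, loss)
            return F.depend(loss, update)

The pre-patch FindLossCNode had no DEPEND case, so such a Depend wrapper fell
through to the final "return -> loss" branch and was itself mistaken for the
loss node; this is presumably what the new test exercises, since its
custom-bprop Mul plus BatchNorm (whose moving-statistics updates are tied to
the output) leaves exactly this kind of Depend in the graph under
semi_auto_parallel.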