!6502 [AutoParallel]Fix auto parallel find loss bug

Merge pull request !6502 from lichen/fix_auto_parallel_find_loss_bug
This commit is contained in:
mindspore-ci-bot 2020-09-19 15:56:56 +08:00 committed by Gitee
commit 5a20b11012
4 changed files with 119 additions and 86 deletions

View File

@ -711,61 +711,6 @@ int32_t GetTupleGetItemIndex(const CNodePtr &cnode) {
return tuple_index_value->cast<Int32ImmPtr>()->value(); return tuple_index_value->cast<Int32ImmPtr>()->value();
} }
// Judge whether the node is a loss, and if there are multiple outputs,
// get which output is a grad according to the tuple getitem.
// Currently, it is not supported that the sens is a tuple.
LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) {
MS_EXCEPTION_IF_NULL(loss_node);
FuncGraphPtr sub_graph = loss_node->func_graph();
MS_EXCEPTION_IF_NULL(sub_graph);
CNodePtr return_node = sub_graph->get_return();
MS_EXCEPTION_IF_NULL(return_node);
if (return_node->inputs().size() < 2) {
MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2";
}
AnfNodePtr pre_node = return_node->input(1);
MS_EXCEPTION_IF_NULL(pre_node);
LossNodeInfo node_info;
// return -> cast
auto pre_cnode = pre_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
if (pre_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
pre_node = pre_cnode->input(1);
}
// return -> loss
if (pre_node == loss_node) {
node_info.has_tuple_getitem = false;
node_info.dout_index = 0;
return node_info;
}
// return -> tuple_getitem -> loss
auto cnode = pre_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto current_value = cnode->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(current_value);
PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(current_prim);
// size of common cnode is larger than 1
if (cnode->inputs().size() < 2) {
MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is smaller than 2";
}
if ((current_prim->name() == TUPLE_GETITEM) && (cnode->input(1) == loss_node)) {
// size of tuple_getitem cnode is 3
auto tuple_index = GetTupleGetItemIndex(cnode);
node_info.has_tuple_getitem = true;
node_info.dout_index = tuple_index;
return node_info;
}
MS_LOG(EXCEPTION) << "Invalid loss";
}
void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) { void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
size_t node_size = node->inputs().size(); size_t node_size = node->inputs().size();
@ -958,13 +903,13 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
} }
void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs) { const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
MS_EXCEPTION_IF_NULL(distribute_operator); MS_EXCEPTION_IF_NULL(distribute_operator);
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
bool is_loss_cnode = bool is_loss_cnode =
std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(), std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
[node](const std::pair<CNodePtr, CNodePtr> &element) { return element.second == node; }); [node](const std::pair<CNodePtr, LossNodeInfo> &element) { return element.second.loss_node == node; });
MirrorOps mirror_ops = distribute_operator->mirror_ops(); MirrorOps mirror_ops = distribute_operator->mirror_ops();
VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op(); VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
@ -1819,7 +1764,20 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
} }
} }
CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { CNodePtr HandleDependLoss(const CNodePtr &cnode) {
// Handle return->depend->loss
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == DEPEND) {
auto depend_before = cnode->input(1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(depend_before);
return HandleDependLoss(depend_before);
}
return cnode;
}
LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
LossNodeInfo loss_node_info;
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
CNodePtr return_node = func_graph->get_return(); CNodePtr return_node = func_graph->get_return();
MS_EXCEPTION_IF_NULL(return_node); MS_EXCEPTION_IF_NULL(return_node);
@ -1831,9 +1789,9 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
auto pre_cnode = pre_node->cast<CNodePtr>(); auto pre_cnode = pre_node->cast<CNodePtr>();
if (pre_cnode == nullptr) { if (pre_cnode == nullptr) {
return nullptr; return loss_node_info;
} }
pre_cnode = HandleDependLoss(pre_cnode);
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
// return -> cast // return -> cast
if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) { if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
@ -1845,7 +1803,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
// notice: the GetNext op has not input // notice: the GetNext op has not input
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
MS_LOG(INFO) << "The loss is: " << current_prim->name(); MS_LOG(INFO) << "The loss is: " << current_prim->name();
return pre_cnode; loss_node_info.loss_node = pre_cnode;
return loss_node_info;
} }
// size of common cnode is larger than 1 // size of common cnode is larger than 1
@ -1855,36 +1814,34 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
// return -> tuple_getitem -> loss // return -> tuple_getitem -> loss
if (current_prim->name() == TUPLE_GETITEM) { if (current_prim->name() == TUPLE_GETITEM) {
auto tuple_index = GetTupleGetItemIndex(pre_cnode);
AnfNodePtr pre_pre_node = pre_cnode->input(1); AnfNodePtr pre_pre_node = pre_cnode->input(1);
MS_EXCEPTION_IF_NULL(pre_pre_node); MS_EXCEPTION_IF_NULL(pre_pre_node);
auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>(); auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
auto value = pre_pre_cnode->input(0)->cast<ValueNodePtr>(); loss_node_info.has_tuple_getitem = true;
MS_EXCEPTION_IF_NULL(value); loss_node_info.dout_index = tuple_index;
PrimitivePtr prim = value->value()->cast<PrimitivePtr>(); loss_node_info.loss_node = pre_pre_cnode;
MS_EXCEPTION_IF_NULL(prim); return loss_node_info;
MS_LOG(DEBUG) << "The loss name is " << prim->name();
return pre_pre_cnode;
} }
// return -> make_tuple // return -> make_tuple
if (current_prim->name() == MAKE_TUPLE) { if (current_prim->name() == MAKE_TUPLE) {
MS_LOG(WARNING) << "The loss have make_tuple, it is not supported"; MS_LOG(WARNING) << "The loss have make_tuple, it is not supported";
return nullptr; return loss_node_info;
} }
// return -> loss // return -> loss
loss_node_info.loss_node = pre_cnode;
MS_LOG(DEBUG) << "The loss name is " << current_prim->name(); MS_LOG(DEBUG) << "The loss name is " << current_prim->name();
return pre_cnode; return loss_node_info;
} }
TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { TensorLayouts GetLossNodeGradOutputLayout(const LossNodeInfo &node_info) {
TensorLayouts ret; TensorLayouts ret;
auto loss_cnode = node_info.loss_node;
MS_EXCEPTION_IF_NULL(loss_cnode); MS_EXCEPTION_IF_NULL(loss_cnode);
AnfNodePtr node = loss_cnode->cast<AnfNodePtr>();
MS_EXCEPTION_IF_NULL(node);
LossNodeInfo node_info = GetLossNodeInfo(node);
ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast<ValueNodePtr>(); ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(prim_anf_node); MS_EXCEPTION_IF_NULL(prim_anf_node);
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>(); PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
@ -2086,9 +2043,9 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
return graph_set; return graph_set;
} }
void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) { void StepSplitSens(const std::pair<CNodePtr, LossNodeInfo> &sens_loss_pair) {
CNodePtr sens_node = sens_loss_pair.first; CNodePtr sens_node = sens_loss_pair.first;
CNodePtr loss_node = sens_loss_pair.second; auto loss_node = sens_loss_pair.second;
auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node); auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node);
if (!loss_grad_layout.empty()) { if (!loss_grad_layout.empty()) {
SplitSens(sens_node, loss_grad_layout[0]); SplitSens(sens_node, loss_grad_layout[0]);
@ -2096,9 +2053,9 @@ void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
} }
// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) // Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &root) { std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphPtr &root) {
MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(root);
std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs; std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs;
for (auto &node : root->nodes()) { for (auto &node : root->nodes()) {
if (!node->isa<CNode>()) { if (!node->isa<CNode>()) {
continue; continue;
@ -2140,12 +2097,12 @@ std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &
MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph."; MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
} }
auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1)); auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
auto loss_cnode = FindLossCNode(func_graph); auto loss_node_info = FindLossCNode(func_graph);
if (loss_cnode == nullptr) { if (loss_node_info.loss_node == nullptr) {
MS_LOG(WARNING) << "Can not find the loss cnode"; MS_LOG(WARNING) << "Can not find the loss cnode";
continue; continue;
} }
std::pair<CNodePtr, CNodePtr> sens_loss_pair = std::make_pair(sens_cnode, loss_cnode); std::pair<CNodePtr, LossNodeInfo> sens_loss_pair = std::make_pair(sens_cnode, loss_node_info);
sens_loss_pairs.push_back(sens_loss_pair); sens_loss_pairs.push_back(sens_loss_pair);
} }
return sens_loss_pairs; return sens_loss_pairs;
@ -2157,7 +2114,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(manager);
TensorRedistribution tensor_redistribution; TensorRedistribution tensor_redistribution;
std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs = GetSensLossPairs(root); std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs = GetSensLossPairs(root);
bool has_backward = !sens_loss_pairs.empty(); bool has_backward = !sens_loss_pairs.empty();
// split sens must before inserting the operators. // split sens must before inserting the operators.
for (auto &pair : sens_loss_pairs) { for (auto &pair : sens_loss_pairs) {
@ -2372,7 +2329,7 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) { std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> root_forward_nodes; std::vector<AnfNodePtr> root_forward_nodes;
auto loss_cnode = FindLossCNode(graph); auto loss_cnode = FindLossCNode(graph).loss_node;
if (loss_cnode == nullptr) { if (loss_cnode == nullptr) {
MS_LOG(WARNING) << "Can not find the loss cnode"; MS_LOG(WARNING) << "Can not find the loss cnode";
return root_forward_nodes; return root_forward_nodes;

View File

@ -39,6 +39,7 @@ const uint64_t kUSecondInSecond = 1000000;
struct LossNodeInfo { struct LossNodeInfo {
bool has_tuple_getitem = false; bool has_tuple_getitem = false;
int dout_index = 0; // now don't support the sens is a tuple int dout_index = 0; // now don't support the sens is a tuple
CNodePtr loss_node = nullptr;
}; };
std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name); std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name);
@ -82,7 +83,7 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node); void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node);
void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs); const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs);
// Generate and init parallel operator // Generate and init parallel operator
OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,

View File

@ -32,7 +32,6 @@ from ..nn.metrics import Loss
from .. import nn from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..context import ParallelMode from ..context import ParallelMode
from ..parallel._utils import _need_to_full, _to_full_tensor
from ..parallel._cost_model_context import _set_multi_subgraphs from ..parallel._cost_model_context import _set_multi_subgraphs
from .dataset_helper import DatasetHelper, connect_network_with_dataset from .dataset_helper import DatasetHelper, connect_network_with_dataset
from . import amp from . import amp
@ -436,8 +435,6 @@ class Model:
# for data sink dataset_helper only iter once, other wise iter epoch_size times. # for data sink dataset_helper only iter once, other wise iter epoch_size times.
for inputs in dataset_helper: for inputs in dataset_helper:
if _need_to_full() and context.get_context("device_target") == "GPU":
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
cb_params.train_dataset_element = inputs cb_params.train_dataset_element = inputs
list_callback.step_begin(run_context) list_callback.step_begin(run_context)
outputs = self._train_network(*inputs) outputs = self._train_network(*inputs)

View File

@ -0,0 +1,78 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
import mindspore.context as context
from mindspore.common.api import _executor
from mindspore import Tensor, Parameter
import mindspore.nn as nn
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
class TwoInputBpropOperator(Cell):
def __init__(self):
super().__init__()
self.op = P.Mul()
self.bp = P.TensorAdd()
def construct(self, x, y):
return self.op(x, y)
def bprop(self, x, y, out, dout):
return self.bp(5, x), self.bp(y, 8)
class ParallelFloorDivBpropNet(Cell):
def __init__(self, mul_size, test_size, strategy=None, strategy2=None):
super().__init__()
mul_np = np.full(mul_size, 0.5, dtype=np.float32)
floordiv_np = np.full(test_size, 0.1, dtype=np.float32)
self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
self.floordiv_weight = Parameter(Tensor(floordiv_np), name="floordiv_weight")
self.mul = TwoInputBpropOperator()
self.floor_div = P.FloorDiv()
self.bn = nn.BatchNorm1d(num_features=96)
if strategy is not None:
self.mul.op.shard(strategy2)
self.mul.bp.shard(strategy2)
self.floor_div.shard(strategy)
def construct(self, inputs, label):
x = self.mul(inputs, self.mul_weight)
x = self.floor_div(x, self.floordiv_weight)
x = self.bn(x)
return x
inputs_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
label_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
def compile_net(net):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_executor.compile(train_net, inputs_, label_)
context.reset_auto_parallel_context()
def test_net():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
strategy = ((4, 1), (4, 1))
net = ParallelFloorDivBpropNet(mul_size=(128, 96), test_size=(128, 96), strategy=strategy, strategy2=strategy)
compile_net(net)