forked from OSSInnovation/mindspore
!6502 [AutoParallel]Fix auto parallel find loss bug
Merge pull request !6502 from lichen/fix_auto_parallel_find_loss_bug
This commit is contained in:
commit
5a20b11012
|
@ -711,61 +711,6 @@ int32_t GetTupleGetItemIndex(const CNodePtr &cnode) {
|
|||
return tuple_index_value->cast<Int32ImmPtr>()->value();
|
||||
}
|
||||
|
||||
// Judge whether the node is a loss, and if there are multiple outputs,
|
||||
// get which output is a grad according to the tuple getitem.
|
||||
// Currently, it is not supported that the sens is a tuple.
|
||||
LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) {
|
||||
MS_EXCEPTION_IF_NULL(loss_node);
|
||||
FuncGraphPtr sub_graph = loss_node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
CNodePtr return_node = sub_graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
if (return_node->inputs().size() < 2) {
|
||||
MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2";
|
||||
}
|
||||
AnfNodePtr pre_node = return_node->input(1);
|
||||
MS_EXCEPTION_IF_NULL(pre_node);
|
||||
|
||||
LossNodeInfo node_info;
|
||||
|
||||
// return -> cast
|
||||
auto pre_cnode = pre_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(pre_cnode);
|
||||
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
||||
if (pre_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
|
||||
pre_node = pre_cnode->input(1);
|
||||
}
|
||||
|
||||
// return -> loss
|
||||
if (pre_node == loss_node) {
|
||||
node_info.has_tuple_getitem = false;
|
||||
node_info.dout_index = 0;
|
||||
return node_info;
|
||||
}
|
||||
|
||||
// return -> tuple_getitem -> loss
|
||||
auto cnode = pre_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto current_value = cnode->input(0)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(current_value);
|
||||
PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(current_prim);
|
||||
// size of common cnode is larger than 1
|
||||
if (cnode->inputs().size() < 2) {
|
||||
MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is smaller than 2";
|
||||
}
|
||||
|
||||
if ((current_prim->name() == TUPLE_GETITEM) && (cnode->input(1) == loss_node)) {
|
||||
// size of tuple_getitem cnode is 3
|
||||
auto tuple_index = GetTupleGetItemIndex(cnode);
|
||||
node_info.has_tuple_getitem = true;
|
||||
node_info.dout_index = tuple_index;
|
||||
return node_info;
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "Invalid loss";
|
||||
}
|
||||
|
||||
void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
size_t node_size = node->inputs().size();
|
||||
|
@ -958,13 +903,13 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
|
|||
}
|
||||
|
||||
void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
|
||||
const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs) {
|
||||
const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
|
||||
MS_EXCEPTION_IF_NULL(distribute_operator);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
bool is_loss_cnode =
|
||||
std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
|
||||
[node](const std::pair<CNodePtr, CNodePtr> &element) { return element.second == node; });
|
||||
[node](const std::pair<CNodePtr, LossNodeInfo> &element) { return element.second.loss_node == node; });
|
||||
|
||||
MirrorOps mirror_ops = distribute_operator->mirror_ops();
|
||||
VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
|
||||
|
@ -1819,7 +1764,20 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
|
|||
}
|
||||
}
|
||||
|
||||
CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
|
||||
CNodePtr HandleDependLoss(const CNodePtr &cnode) {
|
||||
// Handle return->depend->loss
|
||||
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
if (prim->name() == DEPEND) {
|
||||
auto depend_before = cnode->input(1)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(depend_before);
|
||||
return HandleDependLoss(depend_before);
|
||||
}
|
||||
return cnode;
|
||||
}
|
||||
|
||||
LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
|
||||
LossNodeInfo loss_node_info;
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
CNodePtr return_node = func_graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(return_node);
|
||||
|
@ -1831,9 +1789,9 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
|
|||
|
||||
auto pre_cnode = pre_node->cast<CNodePtr>();
|
||||
if (pre_cnode == nullptr) {
|
||||
return nullptr;
|
||||
return loss_node_info;
|
||||
}
|
||||
|
||||
pre_cnode = HandleDependLoss(pre_cnode);
|
||||
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
||||
// return -> cast
|
||||
if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
|
||||
|
@ -1845,7 +1803,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
|
|||
// notice: the GetNext op has not input
|
||||
if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
|
||||
MS_LOG(INFO) << "The loss is: " << current_prim->name();
|
||||
return pre_cnode;
|
||||
loss_node_info.loss_node = pre_cnode;
|
||||
return loss_node_info;
|
||||
}
|
||||
|
||||
// size of common cnode is larger than 1
|
||||
|
@ -1855,36 +1814,34 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
|
|||
|
||||
// return -> tuple_getitem -> loss
|
||||
if (current_prim->name() == TUPLE_GETITEM) {
|
||||
auto tuple_index = GetTupleGetItemIndex(pre_cnode);
|
||||
AnfNodePtr pre_pre_node = pre_cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(pre_pre_node);
|
||||
|
||||
auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
|
||||
auto value = pre_pre_cnode->input(0)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
PrimitivePtr prim = value->value()->cast<PrimitivePtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
MS_LOG(DEBUG) << "The loss name is " << prim->name();
|
||||
return pre_pre_cnode;
|
||||
loss_node_info.has_tuple_getitem = true;
|
||||
loss_node_info.dout_index = tuple_index;
|
||||
loss_node_info.loss_node = pre_pre_cnode;
|
||||
return loss_node_info;
|
||||
}
|
||||
|
||||
// return -> make_tuple
|
||||
if (current_prim->name() == MAKE_TUPLE) {
|
||||
MS_LOG(WARNING) << "The loss have make_tuple, it is not supported";
|
||||
return nullptr;
|
||||
return loss_node_info;
|
||||
}
|
||||
|
||||
// return -> loss
|
||||
loss_node_info.loss_node = pre_cnode;
|
||||
MS_LOG(DEBUG) << "The loss name is " << current_prim->name();
|
||||
return pre_cnode;
|
||||
return loss_node_info;
|
||||
}
|
||||
|
||||
TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
|
||||
TensorLayouts GetLossNodeGradOutputLayout(const LossNodeInfo &node_info) {
|
||||
TensorLayouts ret;
|
||||
auto loss_cnode = node_info.loss_node;
|
||||
MS_EXCEPTION_IF_NULL(loss_cnode);
|
||||
AnfNodePtr node = loss_cnode->cast<AnfNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
LossNodeInfo node_info = GetLossNodeInfo(node);
|
||||
ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast<ValueNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(prim_anf_node);
|
||||
PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
|
||||
|
@ -2086,9 +2043,9 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
|
|||
return graph_set;
|
||||
}
|
||||
|
||||
void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
|
||||
void StepSplitSens(const std::pair<CNodePtr, LossNodeInfo> &sens_loss_pair) {
|
||||
CNodePtr sens_node = sens_loss_pair.first;
|
||||
CNodePtr loss_node = sens_loss_pair.second;
|
||||
auto loss_node = sens_loss_pair.second;
|
||||
auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node);
|
||||
if (!loss_grad_layout.empty()) {
|
||||
SplitSens(sens_node, loss_grad_layout[0]);
|
||||
|
@ -2096,9 +2053,9 @@ void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
|
|||
}
|
||||
|
||||
// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
|
||||
std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &root) {
|
||||
std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphPtr &root) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs;
|
||||
std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs;
|
||||
for (auto &node : root->nodes()) {
|
||||
if (!node->isa<CNode>()) {
|
||||
continue;
|
||||
|
@ -2140,12 +2097,12 @@ std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &
|
|||
MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
|
||||
}
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
|
||||
auto loss_cnode = FindLossCNode(func_graph);
|
||||
if (loss_cnode == nullptr) {
|
||||
auto loss_node_info = FindLossCNode(func_graph);
|
||||
if (loss_node_info.loss_node == nullptr) {
|
||||
MS_LOG(WARNING) << "Can not find the loss cnode";
|
||||
continue;
|
||||
}
|
||||
std::pair<CNodePtr, CNodePtr> sens_loss_pair = std::make_pair(sens_cnode, loss_cnode);
|
||||
std::pair<CNodePtr, LossNodeInfo> sens_loss_pair = std::make_pair(sens_cnode, loss_node_info);
|
||||
sens_loss_pairs.push_back(sens_loss_pair);
|
||||
}
|
||||
return sens_loss_pairs;
|
||||
|
@ -2157,7 +2114,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|||
MS_EXCEPTION_IF_NULL(manager);
|
||||
TensorRedistribution tensor_redistribution;
|
||||
|
||||
std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs = GetSensLossPairs(root);
|
||||
std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs = GetSensLossPairs(root);
|
||||
bool has_backward = !sens_loss_pairs.empty();
|
||||
// split sens must before inserting the operators.
|
||||
for (auto &pair : sens_loss_pairs) {
|
||||
|
@ -2372,7 +2329,7 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
|
|||
std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
std::vector<AnfNodePtr> root_forward_nodes;
|
||||
auto loss_cnode = FindLossCNode(graph);
|
||||
auto loss_cnode = FindLossCNode(graph).loss_node;
|
||||
if (loss_cnode == nullptr) {
|
||||
MS_LOG(WARNING) << "Can not find the loss cnode";
|
||||
return root_forward_nodes;
|
||||
|
|
|
@ -39,6 +39,7 @@ const uint64_t kUSecondInSecond = 1000000;
|
|||
struct LossNodeInfo {
|
||||
bool has_tuple_getitem = false;
|
||||
int dout_index = 0; // now don't support the sens is a tuple
|
||||
CNodePtr loss_node = nullptr;
|
||||
};
|
||||
|
||||
std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name);
|
||||
|
@ -82,7 +83,7 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
|
|||
void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node);
|
||||
|
||||
void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
|
||||
const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs);
|
||||
const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs);
|
||||
|
||||
// Generate and init parallel operator
|
||||
OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
|
||||
|
|
|
@ -32,7 +32,6 @@ from ..nn.metrics import Loss
|
|||
from .. import nn
|
||||
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
|
||||
from ..context import ParallelMode
|
||||
from ..parallel._utils import _need_to_full, _to_full_tensor
|
||||
from ..parallel._cost_model_context import _set_multi_subgraphs
|
||||
from .dataset_helper import DatasetHelper, connect_network_with_dataset
|
||||
from . import amp
|
||||
|
@ -436,8 +435,6 @@ class Model:
|
|||
|
||||
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
|
||||
for inputs in dataset_helper:
|
||||
if _need_to_full() and context.get_context("device_target") == "GPU":
|
||||
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
|
||||
cb_params.train_dataset_element = inputs
|
||||
list_callback.step_begin(run_context)
|
||||
outputs = self._train_network(*inputs)
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
import mindspore.context as context
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore import Tensor, Parameter
|
||||
import mindspore.nn as nn
|
||||
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class TwoInputBpropOperator(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.op = P.Mul()
|
||||
self.bp = P.TensorAdd()
|
||||
|
||||
def construct(self, x, y):
|
||||
return self.op(x, y)
|
||||
|
||||
def bprop(self, x, y, out, dout):
|
||||
return self.bp(5, x), self.bp(y, 8)
|
||||
|
||||
|
||||
class ParallelFloorDivBpropNet(Cell):
|
||||
def __init__(self, mul_size, test_size, strategy=None, strategy2=None):
|
||||
super().__init__()
|
||||
mul_np = np.full(mul_size, 0.5, dtype=np.float32)
|
||||
floordiv_np = np.full(test_size, 0.1, dtype=np.float32)
|
||||
self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
|
||||
self.floordiv_weight = Parameter(Tensor(floordiv_np), name="floordiv_weight")
|
||||
self.mul = TwoInputBpropOperator()
|
||||
self.floor_div = P.FloorDiv()
|
||||
self.bn = nn.BatchNorm1d(num_features=96)
|
||||
if strategy is not None:
|
||||
self.mul.op.shard(strategy2)
|
||||
self.mul.bp.shard(strategy2)
|
||||
self.floor_div.shard(strategy)
|
||||
|
||||
def construct(self, inputs, label):
|
||||
x = self.mul(inputs, self.mul_weight)
|
||||
x = self.floor_div(x, self.floordiv_weight)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
inputs_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
|
||||
label_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
|
||||
|
||||
|
||||
def compile_net(net):
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
train_net = TrainOneStepCell(net, optimizer)
|
||||
train_net.set_auto_parallel()
|
||||
train_net.set_train()
|
||||
_executor.compile(train_net, inputs_, label_)
|
||||
context.reset_auto_parallel_context()
|
||||
|
||||
|
||||
def test_net():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
|
||||
strategy = ((4, 1), (4, 1))
|
||||
net = ParallelFloorDivBpropNet(mul_size=(128, 96), test_size=(128, 96), strategy=strategy, strategy2=strategy)
|
||||
compile_net(net)
|
Loading…
Reference in New Issue