!6502 [AutoParallel]Fix auto parallel find loss bug
Merge pull request !6502 from lichen/fix_auto_parallel_find_loss_bug
commit 5a20b11012
@@ -711,61 +711,6 @@ int32_t GetTupleGetItemIndex(const CNodePtr &cnode) {
   return tuple_index_value->cast<Int32ImmPtr>()->value();
 }
 
-// Judge whether the node is a loss, and if there are multiple outputs,
-// get which output is a grad according to the tuple getitem.
-// Currently, it is not supported that the sens is a tuple.
-LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) {
-  MS_EXCEPTION_IF_NULL(loss_node);
-  FuncGraphPtr sub_graph = loss_node->func_graph();
-  MS_EXCEPTION_IF_NULL(sub_graph);
-  CNodePtr return_node = sub_graph->get_return();
-  MS_EXCEPTION_IF_NULL(return_node);
-  if (return_node->inputs().size() < 2) {
-    MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2";
-  }
-  AnfNodePtr pre_node = return_node->input(1);
-  MS_EXCEPTION_IF_NULL(pre_node);
-
-  LossNodeInfo node_info;
-
-  // return -> cast
-  auto pre_cnode = pre_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(pre_cnode);
-  auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
-  if (pre_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
-    pre_node = pre_cnode->input(1);
-  }
-
-  // return -> loss
-  if (pre_node == loss_node) {
-    node_info.has_tuple_getitem = false;
-    node_info.dout_index = 0;
-    return node_info;
-  }
-
-  // return -> tuple_getitem -> loss
-  auto cnode = pre_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto current_value = cnode->input(0)->cast<ValueNodePtr>();
-  MS_EXCEPTION_IF_NULL(current_value);
-  PrimitivePtr current_prim = current_value->value()->cast<PrimitivePtr>();
-  MS_EXCEPTION_IF_NULL(current_prim);
-  // size of common cnode is larger than 1
-  if (cnode->inputs().size() < 2) {
-    MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is smaller than 2";
-  }
-
-  if ((current_prim->name() == TUPLE_GETITEM) && (cnode->input(1) == loss_node)) {
-    // size of tuple_getitem cnode is 3
-    auto tuple_index = GetTupleGetItemIndex(cnode);
-    node_info.has_tuple_getitem = true;
-    node_info.dout_index = tuple_index;
-    return node_info;
-  }
-
-  MS_LOG(EXCEPTION) << "Invalid loss";
-}
-
 void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   size_t node_size = node->inputs().size();
@@ -958,13 +903,13 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
 }
 
 void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
-                           const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs) {
+                           const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
   MS_EXCEPTION_IF_NULL(distribute_operator);
   MS_EXCEPTION_IF_NULL(node);
 
   bool is_loss_cnode =
     std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
-                [node](const std::pair<CNodePtr, CNodePtr> &element) { return element.second == node; });
+                [node](const std::pair<CNodePtr, LossNodeInfo> &element) { return element.second.loss_node == node; });
 
   MirrorOps mirror_ops = distribute_operator->mirror_ops();
   VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
@@ -1819,7 +1764,20 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
   }
 }
 
-CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
+CNodePtr HandleDependLoss(const CNodePtr &cnode) {
+  // Handle return->depend->loss
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  MS_EXCEPTION_IF_NULL(prim);
+  if (prim->name() == DEPEND) {
+    auto depend_before = cnode->input(1)->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(depend_before);
+    return HandleDependLoss(depend_before);
+  }
+  return cnode;
+}
+
+LossNodeInfo FindLossCNode(const FuncGraphPtr &func_graph) {
+  LossNodeInfo loss_node_info;
   MS_EXCEPTION_IF_NULL(func_graph);
   CNodePtr return_node = func_graph->get_return();
   MS_EXCEPTION_IF_NULL(return_node);
@@ -1831,9 +1789,9 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
 
   auto pre_cnode = pre_node->cast<CNodePtr>();
   if (pre_cnode == nullptr) {
-    return nullptr;
+    return loss_node_info;
   }
+  pre_cnode = HandleDependLoss(pre_cnode);
   auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
   // return -> cast
   if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
@@ -1845,7 +1803,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
   // notice: the GetNext op has not input
   if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) {
     MS_LOG(INFO) << "The loss is: " << current_prim->name();
-    return pre_cnode;
+    loss_node_info.loss_node = pre_cnode;
+    return loss_node_info;
   }
 
   // size of common cnode is larger than 1
@@ -1855,36 +1814,34 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
 
   // return -> tuple_getitem -> loss
   if (current_prim->name() == TUPLE_GETITEM) {
+    auto tuple_index = GetTupleGetItemIndex(pre_cnode);
     AnfNodePtr pre_pre_node = pre_cnode->input(1);
     MS_EXCEPTION_IF_NULL(pre_pre_node);
 
     auto pre_pre_cnode = pre_pre_node->cast<CNodePtr>();
-    auto value = pre_pre_cnode->input(0)->cast<ValueNodePtr>();
-    MS_EXCEPTION_IF_NULL(value);
-    PrimitivePtr prim = value->value()->cast<PrimitivePtr>();
-    MS_EXCEPTION_IF_NULL(prim);
-    MS_LOG(DEBUG) << "The loss name is " << prim->name();
-    return pre_pre_cnode;
+    loss_node_info.has_tuple_getitem = true;
+    loss_node_info.dout_index = tuple_index;
+    loss_node_info.loss_node = pre_pre_cnode;
+    return loss_node_info;
   }
 
   // return -> make_tuple
   if (current_prim->name() == MAKE_TUPLE) {
     MS_LOG(WARNING) << "The loss have make_tuple, it is not supported";
-    return nullptr;
+    return loss_node_info;
   }
 
   // return -> loss
+  loss_node_info.loss_node = pre_cnode;
   MS_LOG(DEBUG) << "The loss name is " << current_prim->name();
-  return pre_cnode;
+  return loss_node_info;
 }
 
-TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
+TensorLayouts GetLossNodeGradOutputLayout(const LossNodeInfo &node_info) {
   TensorLayouts ret;
+  auto loss_cnode = node_info.loss_node;
   MS_EXCEPTION_IF_NULL(loss_cnode);
-  AnfNodePtr node = loss_cnode->cast<AnfNodePtr>();
-  MS_EXCEPTION_IF_NULL(node);
-
-  LossNodeInfo node_info = GetLossNodeInfo(node);
   ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast<ValueNodePtr>();
   MS_EXCEPTION_IF_NULL(prim_anf_node);
   PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
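The hunks above change FindLossCNode to return a LossNodeInfo instead of a bare CNodePtr, and route the lookup through the new HandleDependLoss helper so that a return -> depend -> loss chain no longer hides the loss node. As a rough standalone sketch of that unwrapping pattern (hypothetical Node type for illustration, not the MindSpore IR classes used in the diff; here inputs[0] stands in for the node's first real input):

#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for an IR node: primitive name plus input edges.
struct Node {
  std::string prim_name;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

// Mirrors the shape of HandleDependLoss: as long as the current node is a
// Depend, recurse into its first input; otherwise return it unchanged as the
// candidate loss node.
NodePtr UnwrapDepend(const NodePtr &node) {
  if (node != nullptr && node->prim_name == "Depend" && !node->inputs.empty()) {
    return UnwrapDepend(node->inputs[0]);
  }
  return node;
}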
@@ -2086,9 +2043,9 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
   return graph_set;
 }
 
-void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
+void StepSplitSens(const std::pair<CNodePtr, LossNodeInfo> &sens_loss_pair) {
   CNodePtr sens_node = sens_loss_pair.first;
-  CNodePtr loss_node = sens_loss_pair.second;
+  auto loss_node = sens_loss_pair.second;
   auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node);
   if (!loss_grad_layout.empty()) {
     SplitSens(sens_node, loss_grad_layout[0]);
@@ -2096,9 +2053,9 @@ void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
 }
 
 // Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
-std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &root) {
+std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphPtr &root) {
   MS_EXCEPTION_IF_NULL(root);
-  std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs;
+  std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs;
   for (auto &node : root->nodes()) {
     if (!node->isa<CNode>()) {
       continue;
@@ -2140,12 +2097,12 @@ std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &
       MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
     }
     auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
-    auto loss_cnode = FindLossCNode(func_graph);
-    if (loss_cnode == nullptr) {
+    auto loss_node_info = FindLossCNode(func_graph);
+    if (loss_node_info.loss_node == nullptr) {
       MS_LOG(WARNING) << "Can not find the loss cnode";
       continue;
     }
-    std::pair<CNodePtr, CNodePtr> sens_loss_pair = std::make_pair(sens_cnode, loss_cnode);
+    std::pair<CNodePtr, LossNodeInfo> sens_loss_pair = std::make_pair(sens_cnode, loss_node_info);
     sens_loss_pairs.push_back(sens_loss_pair);
   }
   return sens_loss_pairs;
@@ -2157,7 +2114,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
   MS_EXCEPTION_IF_NULL(manager);
   TensorRedistribution tensor_redistribution;
 
-  std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs = GetSensLossPairs(root);
+  std::vector<std::pair<CNodePtr, LossNodeInfo>> sens_loss_pairs = GetSensLossPairs(root);
   bool has_backward = !sens_loss_pairs.empty();
   // split sens must before inserting the operators.
   for (auto &pair : sens_loss_pairs) {
@@ -2372,7 +2329,7 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
 std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) {
   MS_EXCEPTION_IF_NULL(graph);
   std::vector<AnfNodePtr> root_forward_nodes;
-  auto loss_cnode = FindLossCNode(graph);
+  auto loss_cnode = FindLossCNode(graph).loss_node;
   if (loss_cnode == nullptr) {
     MS_LOG(WARNING) << "Can not find the loss cnode";
     return root_forward_nodes;
@@ -39,6 +39,7 @@ const uint64_t kUSecondInSecond = 1000000;
 struct LossNodeInfo {
   bool has_tuple_getitem = false;
   int dout_index = 0; // now don't support the sens is a tuple
+  CNodePtr loss_node = nullptr;
 };
 
 std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name);
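With the loss_node member added above, a default-constructed LossNodeInfo doubles as the "not found" result, which is why callers in the earlier hunks test loss_node_info.loss_node == nullptr instead of comparing a bare CNodePtr against nullptr. A minimal sketch of that sentinel usage (simplified types; in MindSpore, CNodePtr is a smart pointer rather than the raw pointer used here):

#include <cstdio>

struct CNode {};  // simplified stand-in for the IR node class

struct LossNodeInfo {
  bool has_tuple_getitem = false;
  int dout_index = 0;  // the sens is not yet supported to be a tuple
  CNode *loss_node = nullptr;
};

int main() {
  LossNodeInfo info;  // default state: no loss located
  if (info.loss_node == nullptr) {
    std::printf("Can not find the loss cnode\n");  // matches the WARNING paths
  }
  return 0;
}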
@@ -82,7 +83,7 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
 void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node);
 
 void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
-                           const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs);
+                           const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs);
 
 // Generate and init parallel operator
 OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
@@ -32,7 +32,6 @@ from ..nn.metrics import Loss
 from .. import nn
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from ..context import ParallelMode
-from ..parallel._utils import _need_to_full, _to_full_tensor
 from ..parallel._cost_model_context import _set_multi_subgraphs
 from .dataset_helper import DatasetHelper, connect_network_with_dataset
 from . import amp
@@ -436,8 +435,6 @@ class Model:
 
             # for data sink dataset_helper only iter once, other wise iter epoch_size times.
             for inputs in dataset_helper:
-                if _need_to_full() and context.get_context("device_target") == "GPU":
-                    inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
                 cb_params.train_dataset_element = inputs
                 list_callback.step_begin(run_context)
                 outputs = self._train_network(*inputs)
@@ -0,0 +1,78 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import mindspore as ms
+import mindspore.context as context
+from mindspore.common.api import _executor
+from mindspore import Tensor, Parameter
+import mindspore.nn as nn
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+
+
+class TwoInputBpropOperator(Cell):
+    def __init__(self):
+        super().__init__()
+        self.op = P.Mul()
+        self.bp = P.TensorAdd()
+
+    def construct(self, x, y):
+        return self.op(x, y)
+
+    def bprop(self, x, y, out, dout):
+        return self.bp(5, x), self.bp(y, 8)
+
+
+class ParallelFloorDivBpropNet(Cell):
+    def __init__(self, mul_size, test_size, strategy=None, strategy2=None):
+        super().__init__()
+        mul_np = np.full(mul_size, 0.5, dtype=np.float32)
+        floordiv_np = np.full(test_size, 0.1, dtype=np.float32)
+        self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
+        self.floordiv_weight = Parameter(Tensor(floordiv_np), name="floordiv_weight")
+        self.mul = TwoInputBpropOperator()
+        self.floor_div = P.FloorDiv()
+        self.bn = nn.BatchNorm1d(num_features=96)
+        if strategy is not None:
+            self.mul.op.shard(strategy2)
+            self.mul.bp.shard(strategy2)
+            self.floor_div.shard(strategy)
+
+    def construct(self, inputs, label):
+        x = self.mul(inputs, self.mul_weight)
+        x = self.floor_div(x, self.floordiv_weight)
+        x = self.bn(x)
+        return x
+
+
+inputs_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
+label_ = Tensor(np.random.randn(128, 96).astype(np.float32), dtype=ms.float32)
+
+
+def compile_net(net):
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    train_net.set_train()
+    _executor.compile(train_net, inputs_, label_)
+    context.reset_auto_parallel_context()
+
+
+def test_net():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4, global_rank=0)
+    strategy = ((4, 1), (4, 1))
+    net = ParallelFloorDivBpropNet(mul_size=(128, 96), test_size=(128, 96), strategy=strategy, strategy2=strategy)
+    compile_net(net)