forked from mindspore-Ecosystem/mindspore
support forward graph
This commit is contained in:
parent 001912237e
commit 36a62576e8
@@ -345,7 +345,6 @@ bool FindCommunicationOp(const std::vector<AnfNodePtr> &all_nodes) {
       continue;
     }
     auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
     if (!IsValueNode<Primitive>(cnode->input(0))) {
       continue;
     }
@@ -903,9 +902,15 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
   }
 }
 
-void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, bool is_loss_node) {
+void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
+                           const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs) {
   MS_EXCEPTION_IF_NULL(distribute_operator);
   MS_EXCEPTION_IF_NULL(node);
+
+  bool is_loss_cnode =
+    std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
+                [node](const std::pair<CNodePtr, CNodePtr> &element) { return element.second == node; });
+
   MirrorOps mirror_ops = distribute_operator->mirror_ops();
   VirtualDivOp virtual_div_op = distribute_operator->virtual_div_op();
   // insert mirror op
@@ -914,7 +919,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo
     InsertMirrorOps(mirror_ops, node);
   }
   // insert virtual div op
-  if (!virtual_div_op.empty() && is_loss_node) {
+  if (!virtual_div_op.empty() && is_loss_cnode) {
     MS_LOG(INFO) << "insert virtual div op for " << distribute_operator->name();
     InsertVirtualDivOp(virtual_div_op, node);
   }
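Taken on its own, the new membership test in BackwardCommunication is a plain std::any_of over the (sens, loss) pairs: a cnode is treated as a loss node exactly when it appears as the second element of some pair. A minimal standalone sketch of that test; the CNode stand-in type is an assumption for illustration only (the real CNodePtr is a shared_ptr into mindspore's AnfNode hierarchy):

    #include <algorithm>
    #include <memory>
    #include <utility>
    #include <vector>

    // Stand-in for mindspore's CNodePtr; pointer identity is all the test needs.
    struct CNode {};
    using CNodePtr = std::shared_ptr<CNode>;

    // Mirrors the is_loss_cnode computation added above: membership by identity
    // in the second slot of any (sens, loss) pair.
    bool IsLossCNode(const CNodePtr &node,
                     const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs) {
      return std::any_of(sens_loss_pairs.begin(), sens_loss_pairs.end(),
                         [&node](const std::pair<CNodePtr, CNodePtr> &element) { return element.second == node; });
    }

    int main() {
      auto sens = std::make_shared<CNode>();
      auto loss = std::make_shared<CNode>();
      std::vector<std::pair<CNodePtr, CNodePtr>> pairs{{sens, loss}};
      return IsLossCNode(loss, pairs) ? 0 : 1;  // exits 0: loss is found
    }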
@@ -986,10 +991,6 @@ StrategyPtr ExtractStrategy(std::unordered_map<std::string, ValuePtr> attrs) {
       Dimensions dim;
       if (elements[index]->isa<ValueSequeue>()) {
         ValueTuplePtr value_tuple = elements[index]->cast<ValueTuplePtr>();
-        if (value_tuple == nullptr) {
-          MS_LOG(EXCEPTION) << "Failure:value_tuple is nullptr";
-        }
-
         std::vector<ValuePtr> value_vector = value_tuple->value();
         (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim),
                              [](const ValuePtr &value) { return static_cast<int32_t>(GetValue<int>(value)); });
@@ -1013,7 +1014,6 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
   BaseShapePtr base_shape_ptr = node->Shape();
   if (node->isa<CNode>()) {
     auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
     if (IsValueNode<Primitive>(cnode->input(0))) {
       PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
       MS_EXCEPTION_IF_NULL(prim);
@@ -1190,7 +1190,7 @@ std::pair<AnfNodePtr, int> FindSubGraph(const FuncGraphPtr &graph, const AnfNode
       continue;
     }
     CNodePtr graph_cnode_inp0 = graph_cnode->input(0)->cast<CNodePtr>();
-    if ((graph_cnode_inp0 == nullptr) || !IsValueNode<FuncGraph>(graph_cnode_inp0->input(1))) {
+    if (!IsValueNode<FuncGraph>(graph_cnode_inp0->input(1))) {
      continue;
     }
     FuncGraphPtr graph_sub = GetValueNode<FuncGraphPtr>(graph_cnode_inp0->input(1));
@@ -1692,14 +1692,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
   return pre_cnode;
 }
 
-TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(cnode);
+TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
   TensorLayouts ret;
-  if (!IsValueNode<FuncGraph>(cnode->input(1))) {
-    MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
-  }
-  auto func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
-  auto loss_cnode = FindLossCNode(func_graph);
   MS_EXCEPTION_IF_NULL(loss_cnode);
   AnfNodePtr node = loss_cnode->cast<AnfNodePtr>();
   MS_EXCEPTION_IF_NULL(node);
@@ -1735,16 +1729,16 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &cnode) {
   return ret;
 }
 
-void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) {
+void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) {
   MS_EXCEPTION_IF_NULL(grad_sens_node);
-
-  auto cnode = grad_sens_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  AnfNodePtr sens_tensor_node = cnode->input(1);
+  if (grad_sens_node->size() <= 1) {
+    MS_LOG(EXCEPTION) << "The size of grad sens node is smaller than 2";
+  }
+  AnfNodePtr sens_tensor_node = grad_sens_node->input(1);
   MS_EXCEPTION_IF_NULL(sens_tensor_node);
   Shapes sens_shapes = GetNodeShape(sens_tensor_node);
   if (sens_shapes.size() != 1) {
-    MS_LOG(EXCEPTION) << "SplitSens: GetNodeShape for sens_tensor_node, output size is not 1";
+    MS_LOG(EXCEPTION) << "GetNodeShape for sens_tensor_node, output size is not 1";
   }
   // If the shape of sens tensor is [] or [1], no need to split it.
   Shape sens_shape = sens_shapes[0];
@@ -1780,14 +1774,14 @@ void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_l
       sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
       return;
     }
-    MS_LOG(EXCEPTION) << "SplitSens: the type of sens node is not Tensor or Parameter, it is unsupported now.";
+    MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now.";
   }
 
   // Use _GetTensorSlice operator to split the sens tensor
-  FuncGraphPtr func_graph = cnode->func_graph();  // only cnode can get the graph
+  FuncGraphPtr func_graph = grad_sens_node->func_graph();  // only cnode can get the graph
   MS_EXCEPTION_IF_NULL(func_graph);
   Operator op = CreateGetTensorSliceOp(loss_grad_layout);
-  InsertGetTensorSliceOp(op, cnode, func_graph, 1, SPLIT_SENS);
+  InsertGetTensorSliceOp(op, grad_sens_node, func_graph, 1, SPLIT_SENS);
 }
 
 void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
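For scale: the per-device arithmetic behind splitting a tensor by a layout is that a dimension sharded across n devices keeps shape[i] / n elements on each device. A runnable sketch of that rule under simplified types; SliceShape is an illustrative helper, not mindspore's TensorLayout or _GetTensorSlice machinery:

    #include <cstdio>
    #include <vector>

    // Each dimension i sharded across strategy[i] devices leaves
    // shape[i] / strategy[i] elements per device.
    std::vector<int> SliceShape(const std::vector<int> &shape, const std::vector<int> &strategy) {
      std::vector<int> slice(shape.size());
      for (size_t i = 0; i < shape.size(); ++i) {
        slice[i] = shape[i] / strategy[i];
      }
      return slice;
    }

    int main() {
      // The hybrid strategy (2, 2, 4) from the new tests below, applied to [128, 64, 32].
      auto s = SliceShape({128, 64, 32}, {2, 2, 4});
      std::printf("[%d, %d, %d]\n", s[0], s[1], s[2]);  // prints [64, 32, 8]
      return 0;
    }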
|
@ -1853,7 +1847,6 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
|
|||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if ((cnode->size() < 2) || !IsValueNode<Primitive>(cnode->input(0))) {
|
||||
continue;
|
||||
}
|
||||
|
@@ -1870,55 +1863,12 @@ std::set<FuncGraphPtr> FindForwardGraphByRootNodes(const AnfNodeSet &root_all_no
   return graph_set;
 }
 
-// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
-void StepSplitSens(const AnfNodePtr &node) {
-  if (!node->isa<CNode>()) {
-    return;
-  }
-
-  // cnode(sens)-->cnode(tuple_getitem)
-  auto cnode = node->cast<CNodePtr>();
-  AnfNodePtr expect_tuple_getitem = cnode->input(0);
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
-  if (!expect_tuple_getitem->isa<CNode>()) {
-    return;
-  }
-  auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_tuple_getitem_cnode);
-  if (!IsValueNode<Primitive>(expect_tuple_getitem_cnode->input(0))) {
-    return;
-  }
-  auto expect_tuple_getitem_prim = GetValueNode<PrimitivePtr>(expect_tuple_getitem_cnode->input(0));
-  if (expect_tuple_getitem_prim->name() != TUPLE_GETITEM) {
-    return;
-  }
-
-  // cnode(sens)-->cnode(tuple_getitem)-->cnode
-  AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
-  MS_EXCEPTION_IF_NULL(expect_anonymous);
-  if (!expect_anonymous->isa<CNode>()) {
-    return;
-  }
-
-  // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
-  auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_anonymous_cnode);
-  AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
-  MS_EXCEPTION_IF_NULL(expect_j);
-  if (!expect_j->isa<CNode>()) {
-    return;
-  }
-  auto expect_j_cnode = expect_j->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(expect_j_cnode);
-  if (!IsValueNode<Primitive>(expect_j_cnode->input(0))) {
-    return;
-  }
-  auto expect_j_prim = GetValueNode<PrimitivePtr>(expect_j_cnode->input(0));
-  if (expect_j_prim->name() == J) {
-    auto loss_grad_layout = GetLossNodeGradOutputLayout(expect_j_cnode);
-    if (!loss_grad_layout.empty()) {
-      SplitSens(node, loss_grad_layout[0]);
-    }
-  }
+void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
+  CNodePtr sens_node = sens_loss_pair.first;
+  CNodePtr loss_node = sens_loss_pair.second;
+  auto loss_grad_layout = GetLossNodeGradOutputLayout(loss_node);
+  if (!loss_grad_layout.empty()) {
+    SplitSens(sens_node, loss_grad_layout[0]);
+  }
 }
@@ -1937,26 +1887,77 @@ std::vector<CNodePtr> FindLossCNodeFromRoot(const FuncGraphPtr &root) {
   return loss_node;
 }
 
+// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
+std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &root) {
+  MS_EXCEPTION_IF_NULL(root);
+  std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs;
+  for (auto &node : root->nodes()) {
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+
+    // cnode(sens)-->cnode(tuple_getitem)
+    auto sens_cnode = node->cast<CNodePtr>();
+    AnfNodePtr expect_tuple_getitem = sens_cnode->input(0);
+    MS_EXCEPTION_IF_NULL(expect_tuple_getitem);
+    if (!expect_tuple_getitem->isa<CNode>()) {
+      continue;
+    }
+
+    auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast<CNodePtr>();
+    if (!IsSomePrimitive(expect_tuple_getitem_cnode, TUPLE_GETITEM)) {
+      continue;
+    }
+
+    // cnode(sens)-->cnode(tuple_getitem)-->cnode
+    AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1);
+    MS_EXCEPTION_IF_NULL(expect_anonymous);
+    if (!expect_anonymous->isa<CNode>()) {
+      continue;
+    }
+
+    // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
+    auto expect_anonymous_cnode = expect_anonymous->cast<CNodePtr>();
+    AnfNodePtr expect_j = expect_anonymous_cnode->input(0);
+    MS_EXCEPTION_IF_NULL(expect_j);
+    if (!expect_j->isa<CNode>()) {
+      continue;
+    }
+    auto expect_j_cnode = expect_j->cast<CNodePtr>();
+    if (!IsSomePrimitive(expect_j_cnode, J)) {
+      continue;
+    }
+
+    if (!IsValueNode<FuncGraph>(expect_j_cnode->input(1))) {
+      MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph.";
+    }
+    auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
+    auto loss_cnode = FindLossCNode(func_graph);
+    std::pair<CNodePtr, CNodePtr> sens_loss_pair = std::make_pair(sens_cnode, loss_cnode);
+    sens_loss_pairs.push_back(sens_loss_pair);
+  }
+  return sens_loss_pairs;
+}
+
 void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
                            const FuncGraphManagerPtr &manager) {
   MS_EXCEPTION_IF_NULL(root);
   MS_EXCEPTION_IF_NULL(manager);
   TensorRedistribution tensor_redistribution;
-  AnfNodePtr grad_sens_node = nullptr;
 
-  std::vector<CNodePtr> loss_cnode = FindLossCNodeFromRoot(root);
+  std::vector<std::pair<CNodePtr, CNodePtr>> sens_loss_pairs = GetSensLossPairs(root);
+  bool has_backward = !sens_loss_pairs.empty();
   // split sens must before inserting the operators.
-  for (auto &node : all_nodes) {
+  for (auto &pair : sens_loss_pairs) {
     // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it.
     // If the type of sens node is not Tensor, it is unsupported now, do nothing default.
-    StepSplitSens(node);
+    StepSplitSens(pair);
   }
 
   for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     if (node->isa<CNode>()) {
       auto cnode = node->cast<CNodePtr>();
       MS_EXCEPTION_IF_NULL(cnode);
       if (!IsValueNode<Primitive>(cnode->input(0))) {
         continue;
       }
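IsSomePrimitive is called in GetSensLossPairs but not defined anywhere in this diff. Judging from the inline checks it replaces in the deleted StepSplitSens (an IsValueNode<Primitive> test on input(0) followed by a comparison of the primitive's name), it presumably amounts to the following; this is a hedged reconstruction, not code from this commit:

    // Hypothetical reconstruction of the helper, mirroring the inline pattern
    // removed from the old StepSplitSens; assumed, not taken from this diff.
    bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) {
      if (!IsValueNode<Primitive>(cnode->input(0))) {
        return false;
      }
      auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
      MS_EXCEPTION_IF_NULL(prim);
      return prim->name() == name;
    }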
@@ -1965,11 +1966,6 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
         continue;
       }
 
-      bool is_loss_cnode = false;
-      auto iter = std::find(loss_cnode.begin(), loss_cnode.end(), cnode);
-      if (iter != loss_cnode.end()) {
-        is_loss_cnode = true;
-      }
       // insert forward ops
       InsertForwardOps(distribute_operator, cnode);
 
@@ -1977,7 +1973,9 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
       StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
 
       // insert backward ops
-      BackwardCommunication(distribute_operator, cnode, is_loss_cnode);
+      if (has_backward) {
+        BackwardCommunication(distribute_operator, cnode, sens_loss_pairs);
+      }
 
       // StepReplace
       StepReplace(distribute_operator, cnode);
@@ -2099,7 +2097,6 @@ void SetForwardFlag(const std::vector<AnfNodePtr> &all_nodes) {
       continue;
     }
     auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
     if (!IsValueNode<Primitive>(cnode->input(0))) {
       continue;
     }
@@ -2117,7 +2114,6 @@ void SetForwardFlag(const AnfNodeSet &all_nodes) {
      continue;
     }
     auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
     if (!IsValueNode<Primitive>(cnode->input(0))) {
       continue;
     }
@@ -2146,7 +2142,6 @@ std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const An
       continue;
     }
     auto cnode = node->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(cnode);
     auto root_node_id = node->UniqueIdThroughCopy();
     if (loss_cnode_id == root_node_id) {
       root_forward_nodes = DeepLinkedGraphSearch(cnode);
@@ -82,7 +82,8 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
 
 void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node);
 
-void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, bool is_loss_node);
+void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
+                           const std::vector<std::pair<CNodePtr, CNodePtr>> &sens_loss_pairs);
 
 // Generate and init parallel operator
 OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
@@ -0,0 +1,82 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.nn import Cell
+from mindspore.ops import operations as P
+from mindspore.common.api import _executor
+
+
+class Net(Cell):
+    def __init__(self, mul_weight, strategy1=None, strategy2=None):
+        super().__init__()
+        self.mul = P.Mul().set_strategy(strategy1)
+        self.neg = P.Neg().set_strategy(strategy2)
+        self.mul_weight = Parameter(mul_weight, "w1")
+
+    def construct(self, x, b):
+        out = self.mul(x, self.mul_weight)
+        out = self.neg(out)
+        return out, b
+
+
+_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+
+
+def compile(net):
+    _executor.compile(net, _x, _b)
+    context.reset_auto_parallel_context()
+
+
+def test_forward_graph_data_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((16, 1, 1), (16, 1, 1))
+    strategy2 = ((16, 1, 1), )
+    net = Net(_w1, strategy1, strategy2)
+    compile(net)
+
+
+def test_forward_graph_model_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((1, 1, 16), (1, 1, 16))
+    strategy2 = ((1, 1, 16), )
+    net = Net(_w1, strategy1, strategy2)
+    compile(net)
+
+
+def test_forward_graph_hybrid_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((2, 2, 4), (2, 2, 4))
+    strategy2 = ((2, 2, 4), )
+    net = Net(_w1, strategy1, strategy2)
+    compile(net)
+
+
+def test_forward_graph_auto_parallel():
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
+    net = Net(_w1)
+    compile(net)
+
+
+def test_forward_graph_repeat_calc():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((2, 2, 4), (2, 2, 4))
+    strategy2 = ((1, 2, 2), )
+    net = Net(_w1, strategy1, strategy2)
+    compile(net)