!40455 [AutoParallel] Optimize pipeline parallel

Merge pull request !40455 from lichen/opt_pipeline_parallel
i-robot 2022-08-19 01:15:31 +00:00 committed by Gitee
commit d222d7393c
11 changed files with 176 additions and 193 deletions
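The passes touched below run during graph compilation for (semi-)auto-parallel training; the new _VirtualPipelineEnd handling specifically targets pipeline parallelism. A minimal sketch of that user-level setup, condensed from the unit test added at the end of this commit (the nn.Dense blocks are illustrative stand-ins, not part of the change):

import mindspore.nn as nn
from mindspore import context
from mindspore.nn.wrap.cell_wrapper import PipelineCell

context.set_auto_parallel_context(device_num=16, global_rank=8, pipeline_stages=2,
                                  parallel_mode="semi_auto_parallel")

# Two blocks, each pinned to its own pipeline stage via Cell.pipeline_stage.
blocks = [nn.Dense(64, 64), nn.Dense(64, 64)]
for stage, block in enumerate(blocks):
    block.pipeline_stage = stage

net = PipelineCell(nn.SequentialCell(blocks), 4)  # split every step into 4 micro-batches
optimizer = nn.Lamb(net.trainable_params(), learning_rate=0.01)

Training such a net through Model(net, optimizer=optimizer) compiles the graph in semi_auto_parallel mode; the frontend inserts virtual marker nodes (_VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator, and with this change _VirtualPipelineEnd), and the optimizer passes below strip them again later in compilation.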

View File

@ -116,13 +116,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
{prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
mirror_mini_step_elim_ = MakeSubstitution(std::make_shared<MirrorMiniStepEliminater>(), "mirror_mini_step_eliminate",
prim::kPrimMirrorMiniStep);
mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
"mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
micro_step_allgather_replace_ = MakeSubstitution(std::make_shared<MicroStepAllGatherPass>(),
"micro_step_allgather_replace", prim::kPrimMicroStepAllGather);
virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual_add", prim::kPrimVirtualAdd);
check_bprop_eliminate_ =
MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ =
@ -216,18 +213,16 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Virtual Dataset
virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
"virtual_dataset_eliminate", prim::kPrimVirtualDataset);
// Virtual Dataset
// Virtual Output
virtual_output_eliminate_ =
MakeSubstitution(std::make_shared<VirtualOutputEliminater>(), "virtual_output_eliminate", prim::kPrimVirtualOutput);
// PipelineSplit
receive_eliminate_ = MakeSubstitution(std::make_shared<ReceiveEliminater>(), "receive_eliminate", prim::kPrimReceive);
virtual_accu_grad_ =
MakeSubstitution(std::make_shared<VirtualAccuGradEliminater>(), "virtual_accu_grad", prim::kPrimVirtualAccuGrad);
virtual_assign_add_ =
MakeSubstitution(std::make_shared<VirtualAssignAddEliminater>(), "virtual_assign_add", prim::kPrimVirtualAssignAdd);
mirror_micro_step_ =
MakeSubstitution(std::make_shared<MirrorMicroStepEliminater>(), "mirror_micro_step", prim::kPrimMirrorMicroStep);
parallel_virtual_node_ =
MakeSubstitution(std::make_shared<ParallelVirtualNodeEliminater>(), "parallel_virtual_node",
{prim::kPrimVirtualAssignAdd, prim::kPrimVirtualPipelineEnd, prim::kPrimVirtualAccuGrad,
prim::kPrimMirrorMicroStep, prim::kPrimVirtualAdd, prim::kPrimMirrorMiniStep});
// Convert
print_tuple_wrapper_ =

View File

@ -60,8 +60,6 @@ class OptimizeIRPassLib {
SubstitutionPtr reset_defer_inline_;
SubstitutionPtr depend_value_elim_;
SubstitutionPtr all_reduce_const_elim_;
SubstitutionPtr mirror_mini_step_elim_;
SubstitutionPtr virtual_add_elim_;
SubstitutionPtr mini_step_allgather_replace_;
SubstitutionPtr micro_step_allgather_replace_;
SubstitutionPtr real_op_eliminate_;
@ -128,10 +126,7 @@ class OptimizeIRPassLib {
SubstitutionPtr virtual_output_eliminate_;
// PipelineSplit
SubstitutionPtr receive_eliminate_;
SubstitutionPtr virtual_accu_grad_;
SubstitutionPtr virtual_assign_add_;
SubstitutionPtr mirror_micro_step_;
SubstitutionPtr parallel_virtual_node_;
// Convert
SubstitutionPtr print_tuple_wrapper_;

View File

@ -101,62 +101,35 @@ class VirtualDatasetEliminater : public AnfVisitor {
return node->func_graph()->NewCNode(args);
}
void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimVirtualOutput, X} -> X
class VirtualOutputEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimVirtualOutput) || node->func_graph() == nullptr) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
if (cnode->inputs().size() <= 1) {
if (cnode == nullptr) {
return nullptr;
}
return cnode->input(1);
}
void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimReceive, X} -> prim::kPrimReceive
class ReceiveEliminater : public AnfVisitor {
// {ParallelVirtualNode, X, Y...} -> X
class ParallelVirtualNodeEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimReceive) || node->func_graph() == nullptr) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
if (cnode->inputs().size() == 1) {
if (cnode == nullptr) {
return nullptr;
}
std::vector<AnfNodePtr> args = {cnode->input(0)};
return node->func_graph()->NewCNode(args);
auto input = cnode->input(1);
if (input->isa<CNode>()) {
auto input_cnode = input->cast<CNodePtr>();
input_cnode->set_primal_attrs(cnode->primal_attrs());
}
return cnode->input(1);
}
void Visit(const AnfNodePtr &) override {}
};
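This single ParallelVirtualNodeEliminater takes over from the per-primitive eliminators that this commit removes (VirtualAssignAdd, VirtualAccuGrad, MirrorMicroStep, VirtualAdd, MirrorMiniStep): a node built from any of the registered virtual primitives is replaced by its first real input, and the virtual node's primal attrs are copied onto that input when it is a CNode. A rough Python model of the rewrite, with graph nodes stood in by plain lists (illustrative only, not the MindSpore API):

VIRTUAL_PRIMS = {"_VirtualAssignAdd", "_VirtualPipelineEnd", "_VirtualAccuGrad",
                 "_MirrorMicroStepOperator", "_VirtualAdd", "_MirrorMiniStepOperator"}

def eliminate_parallel_virtual_node(cnode):
    """cnode modeled as [prim_name, input1, input2, ...]; return what replaces it."""
    prim, *inputs = cnode
    if prim not in VIRTUAL_PRIMS or not inputs:
        return cnode                      # not a parallel virtual node: keep it
    # {VirtualNode, X, Y, ...} -> X  (the real pass also forwards primal attrs to X)
    return inputs[0]

print(eliminate_parallel_virtual_node(["_VirtualAccuGrad", "grad_x", "accu_grad"]))  # grad_x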
class VirtualAssignAddEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimVirtualAssignAdd) || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 2) {
return nullptr;
}
return inputs[1];
}
private:
AnfNodePtr x_{nullptr};
};
class LinSpaceValue : public AnfVisitor {
@ -197,44 +170,6 @@ class LinSpaceValue : public AnfVisitor {
}
};
class VirtualAccuGradEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimVirtualAccuGrad) || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 2) {
return nullptr;
}
return inputs[1];
}
private:
AnfNodePtr x_{nullptr};
};
// {prim::kPrimMirrorMicroStep, X, Z} -> X
class MirrorMicroStepEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep) || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 2) {
return nullptr;
}
return inputs[1];
}
void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimSameTypeShape, X, Y} -> X
class SameEliminater : public AnfVisitor {
public:
@ -273,44 +208,6 @@ class CheckBpropEliminater : public AnfVisitor {
AnfNodePtr x_{nullptr};
};
// {prim::kPrimMirrorMiniStep, X, Z} -> X
class MirrorMiniStepEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep) || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 2) {
return nullptr;
}
return inputs[1];
}
void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimVirtualAdd, X, Z} -> X
class VirtualAddEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimVirtualAdd) || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 2) {
return nullptr;
}
return inputs[1];
}
void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimMiniStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
class MiniStepAllGatherPass : public AnfVisitor {
public:
@ -346,8 +243,6 @@ class MiniStepAllGatherPass : public AnfVisitor {
CNodePtr new_node = func_graph->NewCNode(node_input);
return new_node;
}
void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimMicroStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
@ -385,8 +280,6 @@ class MicroStepAllGatherPass : public AnfVisitor {
CNodePtr new_node = func_graph->NewCNode(node_input);
return new_node;
}
void Visit(const AnfNodePtr &) override {}
};
// Reset defer_inline flag

View File

@ -33,27 +33,6 @@
namespace mindspore {
namespace parallel {
const std::set<PrimitivePtr> END_NODE_BLACK_LIST = {
prim::kPrimDepend, prim::kPrimTupleGetItem, prim::kPrimAdd, prim::kPrimSoftmaxCrossEntropyWithLogits,
prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimReshape};
static bool IsInEndNodeBlackList(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (!IsValueNode<Primitive>(cnode->input(0))) {
return true;
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (IsInParallelBlackList(prim)) {
return true;
}
for (auto prim_node = END_NODE_BLACK_LIST.cbegin(); prim_node != END_NODE_BLACK_LIST.cend(); ++prim_node) {
if (IsPrimitiveCNode(cnode, *prim_node)) {
return true;
}
}
return false;
}
AnfNodePtr FindAccuGrad(const CNodePtr &cnode) {
auto pre_node = cnode->input(1);
size_t depth = 0;
@ -557,24 +536,19 @@ void LabelNeedGrad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &root)
}
}
AnfNodePtr GetPreNode(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
std::vector<AnfNodePtr> node_queue = {node};
while (!node_queue.empty()) {
auto cur_node = (*node_queue.cbegin())->cast<CNodePtr>();
if (!cur_node) {
(void)node_queue.erase(node_queue.cbegin());
continue;
}
(void)node_queue.erase(node_queue.cbegin());
if (!IsInEndNodeBlackList(cur_node) && cur_node->HasPrimalAttr(NEED_GRAD)) {
MS_LOG(INFO) << "Pipeline End node: " << cur_node->DebugString();
return cur_node;
}
(void)node_queue.insert(node_queue.cend(), cur_node->inputs().cbegin() + 1, cur_node->inputs().cend());
}
MS_LOG(EXCEPTION) << "Get Pipeline End node failed.";
void InsertVirtualPipelineEndNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager, size_t index) {
auto pre_cnode = cnode->input(index)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
auto graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(graph);
OperatorAttrs attrs_;
auto op = CreateOpInstance(attrs_, "_VirtualPipelineEnd", "end_node");
auto value_node = NewValueNode(op);
auto virtual_end = graph->NewCNode({value_node, pre_cnode});
virtual_end->set_abstract(pre_cnode->abstract());
virtual_end->AddPrimalAttr(PIPELINE_END, pre_cnode->GetPrimalAttr(MICRO));
virtual_end->AddPrimalAttr(MICRO, pre_cnode->GetPrimalAttr(MICRO));
manager->SetEdge(cnode, SizeToInt(index), virtual_end);
}
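Where the old GetPreNode walked backwards from the graph output (skipping the END_NODE_BLACK_LIST) to find a node it could tag as the pipeline end, the new helper marks the end explicitly: it wraps the chosen input of cnode in a _VirtualPipelineEnd node that inherits the input's abstract and carries the PIPELINE_END and MICRO primal attrs. A rough Python model of the rewiring, again with nodes stood in by plain lists (illustrative only):

def insert_virtual_pipeline_end(cnode, index):
    """cnode modeled as [prim_name, input1, ...]; wrap input `index` in an end marker."""
    pre_node = cnode[index]
    virtual_end = ["_VirtualPipelineEnd", pre_node]  # real pass also sets PIPELINE_END / MICRO attrs
    cnode[index] = virtual_end
    return cnode

print(insert_virtual_pipeline_end(["MakeTuple", "loss", "grad"], 1))
# ['MakeTuple', ['_VirtualPipelineEnd', 'loss'], 'grad']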
void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager,
@ -593,7 +567,8 @@ void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphM
}
auto prim = GetCNodePrimitive(node);
if (prim && prim->HasAttr(PIPELINE_END)) {
for (auto &temp_node : cnode->inputs()) {
for (size_t i = 0; i < cnode->inputs().size(); ++i) {
auto temp_node = cnode->input(i);
if (!temp_node->isa<CNode>()) {
continue;
}
@ -601,18 +576,7 @@ void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphM
if (!temp_prim || temp_prim->HasAttr(PIPELINE_END)) {
continue;
}
auto end_node = GetPreNode(temp_node);
MS_EXCEPTION_IF_NULL(end_node);
auto end_cnode = end_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(end_cnode);
auto end_prim = GetCNodePrimitive(end_node);
OperatorAttrs attrs_;
auto op = CreateOpInstance(attrs_, end_prim->name(), "");
auto value_node = NewValueNode(op);
auto new_prim = GetValueNode(value_node)->cast<PrimitivePtr>();
(void)new_prim->SetAttrs(end_prim->attrs());
manager->SetEdge(end_node, 0, value_node);
end_cnode->AddPrimalAttr(PIPELINE_END, end_cnode->GetPrimalAttr(MICRO));
InsertVirtualPipelineEndNode(cnode, manager, i);
}
}
}

View File

@ -58,7 +58,6 @@ void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphM
void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value, size_t max_depth);
void LabelNeedGrad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &root);
void BroadCastNeedGrad(const AnfNodePtr &node, NodeUsersMap *node_user_map, const FuncGraphPtr &root);
AnfNodePtr GetPreNode(const AnfNodePtr &node);
void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager,
const FuncGraphPtr &root);
void SetStridedSliceStrategy(const AnfNodePtr &node);

View File

@ -376,8 +376,6 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.check_bprop_eliminate_,
irpass.switch_layer_defer_inline_,
irpass.replace_applicator_,
irpass.mirror_mini_step_elim_,
irpass.virtual_add_elim_,
irpass.row_tensor_add_zeros_like_,
irpass.mini_step_allgather_replace_,
irpass.micro_step_allgather_replace_,
@ -493,9 +491,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.environ_get_depend_swap_,
irpass.environ_add_const_eliminate_,
irpass.value_based_eliminate_,
irpass.virtual_accu_grad_,
irpass.virtual_assign_add_,
irpass.mirror_micro_step_},
irpass.parallel_virtual_node_},
false, true);
opt::OptPassConfig b_2 = opt::OptPassConfig({
irpass.row_tensor_eliminate_,

View File

@ -913,6 +913,7 @@ GVAR_DEF(PrimitivePtr, kPrimFusedPullWeight, std::make_shared<Primitive>("FusedP
GVAR_DEF(PrimitivePtr, kPrimInitDataSetQueue, std::make_shared<Primitive>("InitDataSetQueue"));
GVAR_DEF(PrimitivePtr, kPrimVirtualAssignAdd, std::make_shared<Primitive>("_VirtualAssignAdd"));
GVAR_DEF(PrimitivePtr, kPrimVirtualAccuGrad, std::make_shared<Primitive>("_VirtualAccuGrad"));
GVAR_DEF(PrimitivePtr, kPrimVirtualPipelineEnd, std::make_shared<Primitive>("_VirtualPipelineEnd"));
GVAR_DEF(PrimitivePtr, kPrimMirrorMicroStep, std::make_shared<Primitive>("_MirrorMicroStepOperator"));
GVAR_DEF(PrimitivePtr, kPrimApplyProximalAdagrad, std::make_shared<Primitive>("ApplyProximalAdagrad"));
GVAR_DEF(PrimitivePtr, kPrimStreamSend, std::make_shared<Primitive>("StreamSend"));

View File

@ -18,6 +18,7 @@ from mindspore import context
from .._grad.grad_base import bprop_getters
from ..operations import _inner_ops as inner
from ..operations import _grad_ops as G
from ..operations.comm_ops import _VirtualPipelineEnd
from .. import functional as F
from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
@ -59,6 +60,17 @@ def get_bprop_roll(self):
    return bprop


@bprop_getters.register(_VirtualPipelineEnd)
def get_bprop_virtual_pipeline_end(self):
    """Backpropagator for _VirtualPipelineEnd."""
    grad = _VirtualPipelineEnd()

    def bprop(x, out, dout):
        # Route dout through another _VirtualPipelineEnd so the end marker also
        # appears in the backward graph; the gradient value itself is unchanged.
        dx = grad(dout)
        return (dx,)

    return bprop


@bprop_getters.register(inner.DynamicResizeNearestNeighbor)
def get_bprop_dynamic_resize_nearest_neighbor(self):
    """Generate bprop for DynamicResizeNearestNeighbor"""

View File

@ -51,7 +51,8 @@ from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV
Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
_HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather)
_HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather,
_VirtualPipelineEnd)
from .control_ops import GeSwitch, Merge
from .custom_ops import (Custom)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,

View File

@ -987,6 +987,28 @@ class _VirtualDiv(PrimitiveWithInfer):
virtual_div = _VirtualDiv()
class _VirtualPipelineEnd(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Does nothing in forward or backward; it only marks the
    end node in pipeline parallel. Takes no arguments.
    """

    @prim_attr_register
    def __init__(self):
        """Initialize _VirtualPipelineEnd."""

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        return x_dtype


virtual_pipeline_end = _VirtualPipelineEnd()
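Since the primitive is a pure pass-through, both infer functions return their inputs untouched; a trivial check of that contract (assuming a MindSpore build that already contains this commit):

import mindspore as ms
from mindspore.ops.operations.comm_ops import _VirtualPipelineEnd

end = _VirtualPipelineEnd()
assert end.infer_shape([32, 64]) == [32, 64]      # shape forwarded unchanged
assert end.infer_dtype(ms.float32) == ms.float32  # dtype forwarded unchanged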
class _VirtualAdd(PrimitiveWithInfer):
"""Auto parallel virtual operator. Do nothing in forward, do Add in backward."""

View File

@ -0,0 +1,105 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.train.model import Model
from mindspore.nn.wrap.cell_wrapper import PipelineCell, Cell
class DatasetLenet():
    """Minimal in-memory dataset providing the iterator interface Model.train expects."""

    def __init__(self, data, length=3):
        self.data = data
        self.index = 1
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.data

    @staticmethod
    def get_dataset_size():
        return 32

    @staticmethod
    def get_repeat_count():
        return 1

    @staticmethod
    def get_batch_size():
        return 32

    def create_tuple_iterator(self, num_epochs=1, do_copy=True):
        return self

    def reset(self):
        self.index = 0


class MatMulCell(Cell):
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.param = Parameter(initializer("zeros", [64, 64]), name="param")
        self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
        self.matmul = P.MatMul().shard(strategy1)
        self.matmul1 = P.MatMul().shard(strategy2)

    def construct(self, x):
        out = self.matmul(x, self.param)
        out = self.matmul1(out, self.param1)
        return out


class Net(nn.Cell):
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.block = nn.CellList()
        for i in range(2):
            cell = MatMulCell(strategy1, strategy2)
            cell.pipeline_stage = i  # pin each sub-cell to its own pipeline stage
            self.block.append(cell)

    def construct(self, x):
        for i in range(2):
            x = self.block[i](x)
        return x


def test_pipeline_split_stage1():
    """
    Feature: pipeline stage1
    Description: pipeline end virtual node
    Expectation: success
    """
    # device_num=16 with pipeline_stages=2 gives 8 devices per stage, so global_rank=8
    # lands on the last stage and exercises the new end-node insertion.
    context.set_auto_parallel_context(device_num=16, global_rank=8, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    dataset = DatasetLenet(data, 3)
    strategy1 = ((4, 1), (1, 2))
    strategy2 = ((2, 2), (2, 1))
    net = PipelineCell(Net(strategy1, strategy2), 4)  # 4 micro-batches per step
    optimizer = nn.Lamb(net.trainable_params(), learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)