forked from mindspore-Ecosystem/mindspore
!40455 [AutoParallel] optimizer Pipeline parallel
Merge pull request !40455 from lichen/opt_pipeline_parallel
Commit: d222d7393c
@@ -116,13 +116,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
    {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
  partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
  same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
  mirror_mini_step_elim_ = MakeSubstitution(std::make_shared<MirrorMiniStepEliminater>(), "mirror_mini_step_eliminate",
                                            prim::kPrimMirrorMiniStep);
  mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
                                                  "mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
  micro_step_allgather_replace_ = MakeSubstitution(std::make_shared<MicroStepAllGatherPass>(),
                                                   "micro_step_allgather_replace", prim::kPrimMicroStepAllGather);
  virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual_add", prim::kPrimVirtualAdd);
  check_bprop_eliminate_ =
    MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
  reset_defer_inline_ =
@@ -216,18 +213,16 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
  // Virtual Dataset
  virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
                                                "virtual_dataset_eliminate", prim::kPrimVirtualDataset);
  // Virtual Dataset

  // Virtual Output
  virtual_output_eliminate_ =
    MakeSubstitution(std::make_shared<VirtualOutputEliminater>(), "virtual_output_eliminate", prim::kPrimVirtualOutput);

  // PipelineSplit
  receive_eliminate_ = MakeSubstitution(std::make_shared<ReceiveEliminater>(), "receive_eliminate", prim::kPrimReceive);
  virtual_accu_grad_ =
    MakeSubstitution(std::make_shared<VirtualAccuGradEliminater>(), "virtual_accu_grad", prim::kPrimVirtualAccuGrad);
  virtual_assign_add_ =
    MakeSubstitution(std::make_shared<VirtualAssignAddEliminater>(), "virtual_assign_add", prim::kPrimVirtualAssignAdd);
  mirror_micro_step_ =
    MakeSubstitution(std::make_shared<MirrorMicroStepEliminater>(), "mirror_micro_step", prim::kPrimMirrorMicroStep);
  parallel_virtual_node_ =
    MakeSubstitution(std::make_shared<ParallelVirtualNodeEliminater>(), "parallel_virtual_node",
                     {prim::kPrimVirtualAssignAdd, prim::kPrimVirtualPipelineEnd, prim::kPrimVirtualAccuGrad,
                      prim::kPrimMirrorMicroStep, prim::kPrimVirtualAdd, prim::kPrimMirrorMiniStep});

  // Convert
  print_tuple_wrapper_ =
@@ -60,8 +60,6 @@ class OptimizeIRPassLib {
  SubstitutionPtr reset_defer_inline_;
  SubstitutionPtr depend_value_elim_;
  SubstitutionPtr all_reduce_const_elim_;
  SubstitutionPtr mirror_mini_step_elim_;
  SubstitutionPtr virtual_add_elim_;
  SubstitutionPtr mini_step_allgather_replace_;
  SubstitutionPtr micro_step_allgather_replace_;
  SubstitutionPtr real_op_eliminate_;
@@ -128,10 +126,7 @@ class OptimizeIRPassLib {
  SubstitutionPtr virtual_output_eliminate_;

  // PipelineSplit
  SubstitutionPtr receive_eliminate_;
  SubstitutionPtr virtual_accu_grad_;
  SubstitutionPtr virtual_assign_add_;
  SubstitutionPtr mirror_micro_step_;
  SubstitutionPtr parallel_virtual_node_;

  // Convert
  SubstitutionPtr print_tuple_wrapper_;
@@ -101,62 +101,35 @@ class VirtualDatasetEliminater : public AnfVisitor {
    return node->func_graph()->NewCNode(args);
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimVirtualOutput, X} -> X
class VirtualOutputEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualOutput) || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode->inputs().size() <= 1) {
    if (cnode == nullptr) {
      return nullptr;
    }
    return cnode->input(1);
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimReceive, X} -> prim::kPrimReceive
class ReceiveEliminater : public AnfVisitor {
// {ParallelVirtualNode, X, Y...} -> X
class ParallelVirtualNodeEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimReceive) || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode->inputs().size() == 1) {
    if (cnode == nullptr) {
      return nullptr;
    }
    std::vector<AnfNodePtr> args = {cnode->input(0)};
    return node->func_graph()->NewCNode(args);
    auto input = cnode->input(1);
    if (input->isa<CNode>()) {
      auto input_cnode = input->cast<CNodePtr>();
      input_cnode->set_primal_attrs(cnode->primal_attrs());
    }
    return cnode->input(1);
  }

  void Visit(const AnfNodePtr &) override {}
};

class VirtualAssignAddEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualAssignAdd) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

 private:
  AnfNodePtr x_{nullptr};
};

class LinSpaceValue : public AnfVisitor {
@@ -197,44 +170,6 @@ class LinSpaceValue : public AnfVisitor {
  }
};

class VirtualAccuGradEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualAccuGrad) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

 private:
  AnfNodePtr x_{nullptr};
};

// {prim::kPrimMirrorMicroStep, X, Z} -> X
class MirrorMicroStepEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimSameTypeShape, X, Y} -> X
class SameEliminater : public AnfVisitor {
 public:
@@ -273,44 +208,6 @@ class CheckBpropEliminater : public AnfVisitor {
  AnfNodePtr x_{nullptr};
};

// {prim::kPrimMirrorMiniStep, X, Z} -> X
class MirrorMiniStepEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimVirtualAdd, X, Z} -> X
class VirtualAddEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualAdd) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimMiniStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
class MiniStepAllGatherPass : public AnfVisitor {
 public:
@@ -346,8 +243,6 @@ class MiniStepAllGatherPass : public AnfVisitor {
    CNodePtr new_node = func_graph->NewCNode(node_input);
    return new_node;
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimMicroStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
@@ -385,8 +280,6 @@ class MicroStepAllGatherPass : public AnfVisitor {
    CNodePtr new_node = func_graph->NewCNode(node_input);
    return new_node;
  }

  void Visit(const AnfNodePtr &) override {}
};

// Reset defer_inline flag
@@ -33,27 +33,6 @@
namespace mindspore {
namespace parallel {
const std::set<PrimitivePtr> END_NODE_BLACK_LIST = {
  prim::kPrimDepend, prim::kPrimTupleGetItem, prim::kPrimAdd, prim::kPrimSoftmaxCrossEntropyWithLogits,
  prim::kPrimMakeTuple, prim::kPrimUpdateState, prim::kPrimReshape};

static bool IsInEndNodeBlackList(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return true;
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (IsInParallelBlackList(prim)) {
    return true;
  }
  for (auto prim_node = END_NODE_BLACK_LIST.cbegin(); prim_node != END_NODE_BLACK_LIST.cend(); ++prim_node) {
    if (IsPrimitiveCNode(cnode, *prim_node)) {
      return true;
    }
  }
  return false;
}

AnfNodePtr FindAccuGrad(const CNodePtr &cnode) {
  auto pre_node = cnode->input(1);
  size_t depth = 0;
@@ -557,24 +536,19 @@ void LabelNeedGrad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &root)
  }
}

AnfNodePtr GetPreNode(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> node_queue = {node};
  while (!node_queue.empty()) {
    auto cur_node = (*node_queue.cbegin())->cast<CNodePtr>();
    if (!cur_node) {
      (void)node_queue.erase(node_queue.cbegin());
      continue;
    }
    (void)node_queue.erase(node_queue.cbegin());
    if (!IsInEndNodeBlackList(cur_node) && cur_node->HasPrimalAttr(NEED_GRAD)) {
      MS_LOG(INFO) << "Pipeline End node: " << cur_node->DebugString();
      return cur_node;
    }
    (void)node_queue.insert(node_queue.cend(), cur_node->inputs().cbegin() + 1, cur_node->inputs().cend());
  }
  MS_LOG(EXCEPTION) << "Get Pipeline End node failed.";
void InsertVirtualPipelineEndNode(const CNodePtr &cnode, const FuncGraphManagerPtr &manager, size_t index) {
  auto pre_cnode = cnode->input(index)->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_cnode);
  auto graph = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  OperatorAttrs attrs_;
  auto op = CreateOpInstance(attrs_, "_VirtualPipelineEnd", "end_node");
  auto value_node = NewValueNode(op);
  auto virtual_end = graph->NewCNode({value_node, pre_cnode});
  virtual_end->set_abstract(pre_cnode->abstract());
  virtual_end->AddPrimalAttr(PIPELINE_END, pre_cnode->GetPrimalAttr(MICRO));
  virtual_end->AddPrimalAttr(MICRO, pre_cnode->GetPrimalAttr(MICRO));
  manager->SetEdge(cnode, SizeToInt(index), virtual_end);
}

void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager,
@@ -593,7 +567,8 @@ void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphM
    }
    auto prim = GetCNodePrimitive(node);
    if (prim && prim->HasAttr(PIPELINE_END)) {
      for (auto &temp_node : cnode->inputs()) {
      for (size_t i = 0; i < cnode->inputs().size(); ++i) {
        auto temp_node = cnode->input(i);
        if (!temp_node->isa<CNode>()) {
          continue;
        }
@@ -601,18 +576,7 @@ void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphM
        if (!temp_prim || temp_prim->HasAttr(PIPELINE_END)) {
          continue;
        }
        auto end_node = GetPreNode(temp_node);
        MS_EXCEPTION_IF_NULL(end_node);
        auto end_cnode = end_node->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(end_cnode);
        auto end_prim = GetCNodePrimitive(end_node);
        OperatorAttrs attrs_;
        auto op = CreateOpInstance(attrs_, end_prim->name(), "");
        auto value_node = NewValueNode(op);
        auto new_prim = GetValueNode(value_node)->cast<PrimitivePtr>();
        (void)new_prim->SetAttrs(end_prim->attrs());
        manager->SetEdge(end_node, 0, value_node);
        end_cnode->AddPrimalAttr(PIPELINE_END, end_cnode->GetPrimalAttr(MICRO));
        InsertVirtualPipelineEndNode(cnode, manager, i);
      }
    }
  }
@@ -58,7 +58,6 @@ void HandleMicroBatch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphM
void BroadCastMicroBatch(const CNodePtr &node, NodeUsersMap *node_users_map, const ValuePtr &value, size_t max_depth);
void LabelNeedGrad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &root);
void BroadCastNeedGrad(const AnfNodePtr &node, NodeUsersMap *node_user_map, const FuncGraphPtr &root);
AnfNodePtr GetPreNode(const AnfNodePtr &node);
void LastStageEndNode(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphManagerPtr &manager,
                      const FuncGraphPtr &root);
void SetStridedSliceStrategy(const AnfNodePtr &node);
@@ -376,8 +376,6 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
    irpass.check_bprop_eliminate_,
    irpass.switch_layer_defer_inline_,
    irpass.replace_applicator_,
    irpass.mirror_mini_step_elim_,
    irpass.virtual_add_elim_,
    irpass.row_tensor_add_zeros_like_,
    irpass.mini_step_allgather_replace_,
    irpass.micro_step_allgather_replace_,
@@ -493,9 +491,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
    irpass.environ_get_depend_swap_,
    irpass.environ_add_const_eliminate_,
    irpass.value_based_eliminate_,
    irpass.virtual_accu_grad_,
    irpass.virtual_assign_add_,
    irpass.mirror_micro_step_},
    irpass.parallel_virtual_node_},
    false, true);
  opt::OptPassConfig b_2 = opt::OptPassConfig({
    irpass.row_tensor_eliminate_,
@@ -913,6 +913,7 @@ GVAR_DEF(PrimitivePtr, kPrimFusedPullWeight, std::make_shared<Primitive>("FusedP
GVAR_DEF(PrimitivePtr, kPrimInitDataSetQueue, std::make_shared<Primitive>("InitDataSetQueue"));
GVAR_DEF(PrimitivePtr, kPrimVirtualAssignAdd, std::make_shared<Primitive>("_VirtualAssignAdd"));
GVAR_DEF(PrimitivePtr, kPrimVirtualAccuGrad, std::make_shared<Primitive>("_VirtualAccuGrad"));
GVAR_DEF(PrimitivePtr, kPrimVirtualPipelineEnd, std::make_shared<Primitive>("_VirtualPipelineEnd"));
GVAR_DEF(PrimitivePtr, kPrimMirrorMicroStep, std::make_shared<Primitive>("_MirrorMicroStepOperator"));
GVAR_DEF(PrimitivePtr, kPrimApplyProximalAdagrad, std::make_shared<Primitive>("ApplyProximalAdagrad"));
GVAR_DEF(PrimitivePtr, kPrimStreamSend, std::make_shared<Primitive>("StreamSend"));
@@ -18,6 +18,7 @@ from mindspore import context
from .._grad.grad_base import bprop_getters
from ..operations import _inner_ops as inner
from ..operations import _grad_ops as G
from ..operations.comm_ops import _VirtualPipelineEnd
from .. import functional as F
from .. import operations as P
from ..composite.multitype_ops.zeros_like_impl import zeros_like
@@ -59,6 +60,17 @@ def get_bprop_roll(self):
    return bprop


@bprop_getters.register(_VirtualPipelineEnd)
def get_bprop_virtual_pipeline_end(self):
    """Backpropagator for _VirtualPipelineEnd."""
    grad = _VirtualPipelineEnd()

    def bprop(x, out, dout):
        dx = grad(dout)
        return (dx,)
    return bprop


@bprop_getters.register(inner.DynamicResizeNearestNeighbor)
def get_bprop_dynamic_resize_nearest_neighbor(self):
    """Generate bprop for DynamicResizeNearestNeighbor"""
@@ -51,7 +51,8 @@ from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV
                       Broadcast,
                       _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
                       _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
                       _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather)
                       _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather,
                       _VirtualPipelineEnd)
from .control_ops import GeSwitch, Merge
from .custom_ops import (Custom)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
@@ -987,6 +987,28 @@ class _VirtualDiv(PrimitiveWithInfer):
virtual_div = _VirtualDiv()


class _VirtualPipelineEnd(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward and backward, mark end node in pipeline parallel.

    Args:
        divisor: float32
    """

    @prim_attr_register
    def __init__(self):
        """Initialize _VirtualPipelineEnd."""

    def infer_shape(self, x_shape):
        return x_shape

    def infer_dtype(self, x_dtype):
        return x_dtype


virtual_pipeline_end = _VirtualPipelineEnd()


class _VirtualAdd(PrimitiveWithInfer):
    """Auto parallel virtual operator. Do nothing in forward, do Add in backward."""
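For intuition only (an editorial sketch, not part of this commit): _VirtualPipelineEnd is a pure identity marker. A plain-Python analogue of its contract, mirroring the infer_shape/infer_dtype methods above and the get_bprop_virtual_pipeline_end registration earlier in this diff, could look like the following; the names virtual_pipeline_end and virtual_pipeline_end_bprop are illustrative helpers, not MindSpore API.

# Illustrative analogue only (hypothetical helper names, not MindSpore API):
# forward is an identity that preserves shape and dtype; backward routes dout
# through the same marker, so the pipeline-end tag also survives into the
# backward graph, matching the registered bprop.
def virtual_pipeline_end(x):
    """Forward: return the input unchanged."""
    return x


def virtual_pipeline_end_bprop(x, out, dout):
    """Backward: wrap dout in the same marker, as get_bprop_virtual_pipeline_end does."""
    return (virtual_pipeline_end(dout),)


assert virtual_pipeline_end(3.0) == 3.0
assert virtual_pipeline_end_bprop(3.0, 3.0, 1.0) == (1.0,)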
@@ -0,0 +1,105 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.train.model import Model
from mindspore.nn.wrap.cell_wrapper import PipelineCell, Cell


class DatasetLenet():
    def __init__(self, data, length=3):
        self.data = data
        self.index = 1
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.data

    @staticmethod
    def get_dataset_size():
        return 32

    @staticmethod
    def get_repeat_count():
        return 1

    @staticmethod
    def get_batch_size():
        return 32

    def create_tuple_iterator(self, num_epochs=1, do_copy=True):
        return self

    def reset(self):
        self.index = 0


class MatMulCell(Cell):
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.param = Parameter(initializer("zeros", [64, 64]), name="param")
        self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
        self.matmul = P.MatMul().shard(strategy1)
        self.matmul1 = P.MatMul().shard(strategy2)

    def construct(self, x):
        out = self.matmul(x, self.param)
        out = self.matmul1(out, self.param1)
        return out


class Net(nn.Cell):
    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.block = nn.CellList()
        for i in range(2):
            cell = MatMulCell(strategy1, strategy2)
            cell.pipeline_stage = i
            self.block.append(cell)

    def construct(self, x):
        for i in range(2):
            x = self.block[i](x)
        return x


def test_pipeline_split_stage1():
    """
    Feature: pipeline stage1
    Description: pipeline end virtual node
    Expectation: success
    """
    context.set_auto_parallel_context(device_num=16, global_rank=8, pipeline_stages=2)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    data = Tensor(np.ones([32, 64]), dtype=ms.float32)
    dataset = DatasetLenet(data, 3)
    strategy1 = ((4, 1), (1, 2))
    strategy2 = ((2, 2), (2, 1))
    net = PipelineCell(Net(strategy1, strategy2), 4)
    optimizer = nn.Lamb(net.trainable_params(), learning_rate=0.01)
    model = Model(net, optimizer=optimizer)
    model.train(2, dataset, dataset_sink_mode=False)