forked from mindspore-Ecosystem/mindspore
auto parallel adasum support data parallel and hybrid parallel
parent f21b6914a3
commit 19236b1a70
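In short, this change lets the AdaSum optimizer wrappers be applied when the network runs in data_parallel or hybrid_parallel mode, not only under semi_auto_parallel/auto_parallel. A minimal usage sketch (hypothetical script: the network class, device count and distributed launch are assumed; only APIs touched by this diff are used):

import mindspore.nn as nn
from mindspore import context
from mindspore.communication.management import init

init()  # assumes a multi-device job has already been launched
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=32)

net = MyNet()  # hypothetical Cell
opt = nn.AdaSumByGradWrapCell(nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9))
train_net = nn.TrainOneStepCell(net, opt)  # builds the per-server grad reducer shown later in this diff
train_net.set_train()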
@@ -511,6 +511,8 @@ constexpr char REF_TO_EMBED[] = "RefToEmbed";
 constexpr char STOP_GRADIENT[] = "stop_gradient";
 constexpr char UPDATESTATE[] = "UpdateState";
 constexpr char LOAD[] = "Load";
+constexpr char OPPOSITE_RANK[] = "opposite_rank";
+constexpr char TARGET_PARAM[] = "target_param";
 
 // Batch parallel black list
 constexpr char TENSOR_SCATTER_UPDATE[] = "TensorScatterUpdate";
@@ -712,7 +712,7 @@ RankList GetRankListByLayout(const std::shared_ptr<TensorLayout> &target_param_l
 std::vector<bool> IsBorderAdaSumSendReceive(const AnfNodePtr &node, const RankList &group_devices) {
   bool is_send = IsPrimitiveCNode(node, prim::kPrimSend);
   PrimitivePtr send_rec_prim = GetCNodePrimitive(node);
-  int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr("opposite_rank"));
+  int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr(OPPOSITE_RANK));
   int64_t rank = g_device_manager->global_rank();
   int64_t adasum_rank_distance = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
   if (adasum_rank_distance < ADASUM_MIN_DIS) {
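For reference, adasum_rank_distance computed above is just the stride between consecutive ranks of the (evenly spaced) device group that owns the parameter; a quick worked example with a made-up group:

group_devices = [0, 8, 16, 24]  # hypothetical: one rank per server
adasum_rank_distance = (group_devices[-1] - group_devices[0]) // (len(group_devices) - 1)
print(adasum_rank_distance)  # 8, the gap between neighbouring ranks in the group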
@@ -882,7 +882,7 @@ void HandleAdaSumPureModelParallel(const AnfNodePtr &node) {
     return;
   }
   PrimitivePtr send_rec_prim = GetCNodePrimitive(node);
-  int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr("opposite_rank"));
+  int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr(OPPOSITE_RANK));
   int64_t rank = g_device_manager->global_rank();
   CNodePtr cnode = node->cast<CNodePtr>();
   auto pre_cnode = RealInputNode(cnode, 1);
@@ -897,7 +897,8 @@ void HandleAdaSumPureModelParallel(const AnfNodePtr &node) {
   AnfNodeIndexSet squeeze_input_node_user_set = manager->node_users()[squeeze_input];
   for (auto &squeeze_input_user : squeeze_input_node_user_set) {
     if (IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimSqueeze) ||
-        IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimUpdateState)) {
+        IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimUpdateState) ||
+        IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimMakeTuple)) {
       continue;
     }
     manager->Replace(squeeze_input_user.first, squeeze_input);
@@ -923,10 +924,10 @@ bool HandleAdaSum(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_n
     std::string target_param;
     CNodePtr cnode = node->cast<CNodePtr>();
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)->cast<ValueNodePtr>());
-    if (!prim->HasAttr("target_param")) {
+    if (!prim->HasAttr(TARGET_PARAM)) {
       continue;
     }
-    target_param = GetValue<std::string>(prim->GetAttr("target_param"));
+    target_param = GetValue<std::string>(prim->GetAttr(TARGET_PARAM));
     auto target_param_layout = (*adasum_param_tensor_layout_map)[target_param];
     RankList group_devices = GetRankListByLayout(target_param_layout);
     // only model parallel
@@ -971,6 +971,9 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
       continue;
     }
     use_apply = SkipTrivialNodesMoveDown(manager, use_apply);
+    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
+      continue;
+    }
     ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
     MS_EXCEPTION_IF_NULL(prim_anf_node);
     PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
@@ -29,6 +29,7 @@ from mindspore.ops import operations as P
 from mindspore.ops.operations._inner_ops import Send, Receive
 from mindspore.common.tensor import Tensor
 from mindspore.common import dtype as mstype
+from mindspore.communication.management import create_group
 
 __all__ = ["AdaSumByDeltaWeightWrapCell", "AdaSumByGradWrapCell"]
 
@@ -184,7 +185,8 @@ class _AdaSum(Cell):
         self.parameter_divisibility_list = []
         self.allreduce_node_num_list = []
         last_delta_weights = []
 
+        fusion_attr = "fusion" if context.get_auto_parallel_context("parallel_mode") \
+                      in ["data_parallel", "hybrid_parallel"] else "origin_fusion"
         for step in range(self.calc_times):
             current_group = self.device_number * (2 ** step)
             sr_target = self.rank
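The attribute name picked here decides whether the Send/Receive/AllReduce operators built below carry the final "fusion" id directly (data_parallel/hybrid_parallel, where presumably no later auto-parallel pass rewrites it) or the provisional "origin_fusion" id; a one-line sketch of the switch with a hypothetical mode value:

parallel_mode = "data_parallel"  # hypothetical; the real code reads context.get_auto_parallel_context("parallel_mode")
fusion_attr = "fusion" if parallel_mode in ["data_parallel", "hybrid_parallel"] else "origin_fusion"
print(fusion_attr)  # "fusion"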
@@ -208,13 +210,13 @@ class _AdaSum(Cell):
                 for shape, dtype, name in left_delta_weights:
                     send_tag = self._hash(step, sr_target, weights_index)
                     send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
-                    send.add_prim_attr("origin_fusion", fusion_id)
+                    send.add_prim_attr(fusion_attr, fusion_id)
                     send.add_prim_attr("opposite_rank", dest_target)
                     send.add_prim_attr("target_param", name)
                     recv_tag = self._hash(step, dest_target, weights_index)
                     recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
                                    group="hccl_world_group")
-                    recv.add_prim_attr("origin_fusion", fusion_id)
+                    recv.add_prim_attr(fusion_attr, fusion_id)
                     recv.add_prim_attr("opposite_rank", dest_target)
                     recv.add_prim_attr("target_param", name)
                     send_left.append(send)
@@ -223,19 +225,18 @@ class _AdaSum(Cell):
                 for shape, dtype, name in right_delta_weights:
                     send_tag = self._hash(step, sr_target, weights_index)
                     send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
-                    send.add_prim_attr("origin_fusion", fusion_id + 1)
+                    send.add_prim_attr(fusion_attr, fusion_id + 1)
                     send.add_prim_attr("opposite_rank", dest_target)
                     send.add_prim_attr("target_param", name)
                     recv_tag = self._hash(step, dest_target, weights_index)
                     recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
                                    group="hccl_world_group")
-                    recv.add_prim_attr("origin_fusion", fusion_id + 1)
+                    recv.add_prim_attr(fusion_attr, fusion_id + 1)
                     recv.add_prim_attr("opposite_rank", dest_target)
                     recv.add_prim_attr("target_param", name)
                     send_right.append(send)
                     recv_right.append(recv)
                     weights_index += 1
-
             if self.send_node and self.send_node[-1]:
                 self.send_list_forward.append(send_left)
                 self.send_list_rollback.append(send_right)
@@ -250,17 +251,21 @@ class _AdaSum(Cell):
             last_delta_weights = left_delta_weights
             param_allreduce_list = []
             neighbor_ids = []
+            rank_ids = []
             for index in range(2 ** (step + 1)):
                 node_rank = self.rank // self.device_number
                 double_d = 2 ** (step + 1)
                 neighbor_id = (node_rank // double_d * double_d + index) * self.device_number + \
                               self.rank % self.device_number
                 neighbor_ids.append(str(neighbor_id))
+                rank_ids.append(neighbor_id)
             group_name = "-".join(neighbor_ids)
+            if context.get_auto_parallel_context("parallel_mode") in ["data_parallel", "hybrid_parallel"]:
+                create_group(group_name, rank_ids)
             for parameter in self.parameter_tuple:
                 allreduce = P.AllReduce("sum", group_name)
                 allreduce.add_prim_attr("target_param", "adasum_delta_weight." + parameter.name)
-                allreduce.add_prim_attr("origin_fusion", fusion_id + 2)
+                allreduce.add_prim_attr(fusion_attr, fusion_id + 2)
                 allreduce.add_prim_attr("step", step)
                 param_allreduce_list.append(allreduce)
             self.allreduce_list.append(param_allreduce_list)
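The neighbor-id arithmetic above doubles the group span every step while keeping the local device offset fixed; a standalone sketch with hypothetical values (rank 17, device_number 8, i.e. 32 devices split into 4 nodes of 8):

def adasum_group(rank, device_number, step):
    # Same formula as in _AdaSum above, evaluated outside the graph.
    node_rank = rank // device_number
    double_d = 2 ** (step + 1)
    return [(node_rank // double_d * double_d + index) * device_number + rank % device_number
            for index in range(double_d)]

print(adasum_group(17, 8, 0))  # [17, 25]        step 0 pairs two neighbouring nodes
print(adasum_group(17, 8, 1))  # [1, 9, 17, 25]  step 1 spans four nodes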
@@ -358,15 +363,18 @@ def _clone_weight_process(scale, weight):
     return scale_mul(weight, scale)
 
 def _parallel_check():
-    if context.get_auto_parallel_context("parallel_mode") not in ["semi_auto_parallel",
-                                                                  "auto_parallel", "data_parallel"]:
-        raise RuntimeError("Stand alone and hybrid parallel mode is not supported to apply adasum.")
-    if context.get_auto_parallel_context("parallel_mode") == "data_parallel":
-        logger.warning("For data parallel mode, it is recommended to using mindspore.boost to enable adasum.")
+    """Parallel infos checking"""
+    if context.get_auto_parallel_context("parallel_mode") == "stand_alone":
+        raise RuntimeError("Stand alone mode is not supported to apply adasum.")
+    if context.get_auto_parallel_context("parallel_mode") in ["data_parallel", "hybrid_parallel"]:
+        logger.warning("For data parallel mode or hybrid parallel mode, "
+                       "it is recommended to using mindspore.boost to enable adasum.")
     if context.get_auto_parallel_context("enable_parallel_optimizer"):
         raise RuntimeError("Currently, the optimizer shard is not supported with applying adasum.")
     if context.get_auto_parallel_context("pipeline_stages") > 1:
         raise RuntimeError("Currently, the pipeline parallel is not supported with applying adasum.")
+    if _get_stage_device_num() < 16:
+        raise RuntimeError("The device_num should be at least 16 when applying adasum.")
 
 class AdaSumByGradWrapCell(Cell):
     r"""
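A hypothetical context setup that clears all of these checks (at least 16 devices, optimizer shard and pipeline parallel left at their defaults):

from mindspore import context
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=32)  # 32 >= 16, no optimizer shard, no pipeline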
@@ -356,7 +356,18 @@ class TrainOneStepCell(Cell):
         if self.reducer_flag:
             self.mean = _get_gradients_mean()
             self.degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
+            if isinstance(self.optimizer, (nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell)):
+                from mindspore.communication.management import get_group_size, create_group, get_rank
+                group_number = get_group_size() // 8
+                self.degree = int(self.degree / group_number)
+                group_list = [list(range(x * self.degree, (x + 1) * self.degree)) for x in range(group_number)]
+                current_index = get_rank() // 8
+                server_group_name = "allreduce_" + str(current_index)
+                create_group(server_group_name, group_list[current_index])
+                self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree,
+                                                           group=server_group_name)
+            else:
+                self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
 
     def construct(self, *inputs):
         loss = self.network(*inputs)
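The branch added above splits the world into per-server allreduce groups of 8 ranks and rebuilds the grad reducer on the local group; the group arithmetic in plain Python, for a hypothetical 32-device job:

group_size = 32                       # would come from get_group_size()
group_number = group_size // 8        # 4 server groups
degree = group_size // group_number   # 8 ranks per group
group_list = [list(range(x * degree, (x + 1) * degree)) for x in range(group_number)]
rank = 17                             # hypothetical rank; get_rank() in the real code
print(group_list[rank // 8])          # [16, 17, ..., 23] -> group "allreduce_2"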
@@ -393,7 +393,7 @@ class DistributedGradReducer(Cell):
                                                                fusion_type)
         if not param_fusion:
             self.split_fusion = False
-            self.allreduce = AllReduce().add_prim_attr('fusion', fusion_type)
+            self.allreduce = AllReduce('sum', group).add_prim_attr('fusion', fusion_type)
         self.allgather = AllGather(group)
         ps_filter = lambda x: x.is_param_ps
         self.ps_parameters = tuple(ps_filter(x) for x in parameters)
@@ -0,0 +1,134 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum, AdaSumByDeltaWeightWrapCell, AdaSumByGradWrapCell
from mindspore.ops import operations as P


class Net(Cell):
    def __init__(self, strategy1=None, strategy2=None, strategy3=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.matmul = P.MatMul().shard(strategy2)
        self.gather = P.Gather().shard(strategy3)
        self.reduce_sum = P.ReduceSum()
        self.mul_weight = Parameter(Tensor(np.ones([64, 32]), dtype=ms.float32), "w1")
        self.matmul_weight = Parameter(Tensor(np.ones([32, 32]), dtype=ms.float32), "w2")
        self.embedding_table = Parameter(Tensor(np.ones([64, 32]), dtype=ms.float32), "embedding_table")

    def construct(self, x, b):
        out = self.gather(self.embedding_table, x, 0)
        out = self.matmul(out, self.matmul_weight)
        out = self.mul(out, self.mul_weight)
        out = out + b
        return self.reduce_sum(out)


_x = Tensor(np.ones([64]), dtype=ms.int32)
_b = Tensor(np.ones([64, 32]), dtype=ms.float32)


def compile_net(net, by_grad=True):
    if by_grad:
        optimizer = AdaSumByGradWrapCell(Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9))
    else:
        optimizer = AdaSumByDeltaWeightWrapCell(Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9))
    train_net = TrainOneStepCell(net, optimizer)
    train_net.set_auto_parallel()
    train_net.set_train()
    _cell_graph_executor.compile(train_net, _x, _b)
    context.reset_auto_parallel_context()


def test_auto_parallel_adasum1():
    """
    Feature: adasum in auto parallel.
    Description: verify adasum by mul/matmul/gather, rank0, dp, mp, not_full_dp
    Expectation: compile done without error.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
    mul_strategy1 = ((8, 4), (8, 4))
    matmul_strategy2 = ((8, 1), (1, 1))
    gather_strategy3 = ((1, 1), (32,))
    net = Net(mul_strategy1, matmul_strategy2, gather_strategy3)
    compile_net(net)


def test_auto_parallel_adasum2():
    """
    Feature: adasum in auto parallel.
    Description: verify adasum by mul/matmul/gather, rank0, dp, mp, not_full_dp
    Expectation: compile done without error.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
    mul_strategy1 = ((8, 4), (8, 4))
    matmul_strategy2 = ((8, 1), (1, 1))
    gather_strategy3 = ((1, 1), (32,))
    net = Net(mul_strategy1, matmul_strategy2, gather_strategy3)
    compile_net(net, by_grad=False)


def test_auto_parallel_adasum3():
    """
    Feature: adasum in auto parallel.
    Description: verify adasum by mul/matmul/gather, rank0, mix_dp_mp, mp
    Expectation: compile done without error.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
    mul_strategy1 = ((8, 4), (8, 4))
    matmul_strategy2 = ((8, 4), (4, 1))
    gather_strategy3 = ((32, 1), (1,))
    net = Net(mul_strategy1, matmul_strategy2, gather_strategy3)
    compile_net(net)


def test_auto_parallel_adasum4():
    """
    Feature: adasum in auto parallel.
    Description: verify adasum by mul/matmul/gather, rank0, mix_dp_mp, mp
    Expectation: compile done without error.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
    mul_strategy1 = ((8, 4), (8, 4))
    matmul_strategy2 = ((8, 4), (4, 1))
    gather_strategy3 = ((32, 1), (1,))
    net = Net(mul_strategy1, matmul_strategy2, gather_strategy3)
    compile_net(net, by_grad=False)


def test_auto_parallel_adasum5():
    """
    Feature: adasum in auto parallel.
    Description: verify adasum by mul/matmul/gather, rank16, dp, mp, not_full_dp
    Expectation: compile done without error.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=16)
    mul_strategy1 = ((8, 4), (8, 4))
    matmul_strategy2 = ((8, 1), (1, 1))
    gather_strategy3 = ((1, 1), (32,))
    net = Net(mul_strategy1, matmul_strategy2, gather_strategy3)
    compile_net(net)


def test_auto_parallel_adasum6():
    """
    Feature: adasum in auto parallel.
    Description: verify adasum by mul/matmul/gather, rank16, dp, mp, not_full_dp
    Expectation: compile done without error.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=16)
    mul_strategy1 = ((8, 4), (8, 4))
    matmul_strategy2 = ((8, 1), (1, 1))
    gather_strategy3 = ((1, 1), (32,))
    net = Net(mul_strategy1, matmul_strategy2, gather_strategy3)
    compile_net(net, by_grad=False)