auto parallel adasum: support data parallel and hybrid parallel

Author: yao_yf 2022-02-22 09:40:47 +08:00
parent f21b6914a3
commit 19236b1a70
7 changed files with 178 additions and 19 deletions

View File

@@ -511,6 +511,8 @@ constexpr char REF_TO_EMBED[] = "RefToEmbed";
constexpr char STOP_GRADIENT[] = "stop_gradient";
constexpr char UPDATESTATE[] = "UpdateState";
constexpr char LOAD[] = "Load";
constexpr char OPPOSITE_RANK[] = "opposite_rank";
constexpr char TARGET_PARAM[] = "target_param";
// Batch parallel black list
constexpr char TENSOR_SCATTER_UPDATE[] = "TensorScatterUpdate";

View File

@@ -712,7 +712,7 @@ RankList GetRankListByLayout(const std::shared_ptr<TensorLayout> &target_param_l
std::vector<bool> IsBorderAdaSumSendReceive(const AnfNodePtr &node, const RankList &group_devices) {
bool is_send = IsPrimitiveCNode(node, prim::kPrimSend);
PrimitivePtr send_rec_prim = GetCNodePrimitive(node);
int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr("opposite_rank"));
int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr(OPPOSITE_RANK));
int64_t rank = g_device_manager->global_rank();
int64_t adasum_rank_distance = (group_devices.back() - group_devices.front()) / (group_devices.size() - 1);
if (adasum_rank_distance < ADASUM_MIN_DIS) {
@@ -882,7 +882,7 @@ void HandleAdaSumPureModelParallel(const AnfNodePtr &node) {
return;
}
PrimitivePtr send_rec_prim = GetCNodePrimitive(node);
int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr("opposite_rank"));
int64_t origin_dest_rank = GetValue<int64_t>(send_rec_prim->GetAttr(OPPOSITE_RANK));
int64_t rank = g_device_manager->global_rank();
CNodePtr cnode = node->cast<CNodePtr>();
auto pre_cnode = RealInputNode(cnode, 1);
@@ -897,7 +897,8 @@ void HandleAdaSumPureModelParallel(const AnfNodePtr &node) {
AnfNodeIndexSet squeeze_input_node_user_set = manager->node_users()[squeeze_input];
for (auto &squeeze_input_user : squeeze_input_node_user_set) {
if (IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimSqueeze) ||
IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimUpdateState)) {
IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimUpdateState) ||
IsPrimitiveCNode(squeeze_input_user.first, prim::kPrimMakeTuple)) {
continue;
}
manager->Replace(squeeze_input_user.first, squeeze_input);
@@ -923,10 +924,10 @@ bool HandleAdaSum(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_n
std::string target_param;
CNodePtr cnode = node->cast<CNodePtr>();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)->cast<ValueNodePtr>());
if (!prim->HasAttr("target_param")) {
if (!prim->HasAttr(TARGET_PARAM)) {
continue;
}
target_param = GetValue<std::string>(prim->GetAttr("target_param"));
target_param = GetValue<std::string>(prim->GetAttr(TARGET_PARAM));
auto target_param_layout = (*adasum_param_tensor_layout_map)[target_param];
RankList group_devices = GetRankListByLayout(target_param_layout);
// only model parallel

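The hunks above replace the hard-coded "opposite_rank" and "target_param" strings with the new OPPOSITE_RANK and TARGET_PARAM constants and key the border handling on adasum_rank_distance, the gap between adjacent ranks in the device group derived from a parameter's layout. A minimal sketch of that arithmetic (Python, for illustration only; the real check is the C++ IsBorderAdaSumSendReceive above):

# Sketch of the adasum_rank_distance arithmetic from IsBorderAdaSumSendReceive.
def adasum_rank_distance(group_devices):
    """group_devices: sorted rank list obtained from the parameter's tensor layout."""
    return (group_devices[-1] - group_devices[0]) // (len(group_devices) - 1)

# A parameter replicated across ranks 0, 8, 16, 24 gives a distance of 8, while one
# spread over ranks 0..3 gives a distance of 1; the C++ pass compares this value
# against ADASUM_MIN_DIS before deciding how to rewrite the Send/Receive pair.
print(adasum_rank_distance([0, 8, 16, 24]))  # 8
print(adasum_rank_distance([0, 1, 2, 3]))    # 1
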
View File

@@ -971,6 +971,9 @@ std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &
continue;
}
use_apply = SkipTrivialNodesMoveDown(manager, use_apply);
if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
continue;
}
ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(prim_anf_node);
PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();

View File

@@ -29,6 +29,7 @@ from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive
from mindspore.common.tensor import Tensor
from mindspore.common import dtype as mstype
from mindspore.communication.management import create_group
__all__ = ["AdaSumByDeltaWeightWrapCell", "AdaSumByGradWrapCell"]
@@ -184,7 +185,8 @@ class _AdaSum(Cell):
self.parameter_divisibility_list = []
self.allreduce_node_num_list = []
last_delta_weights = []
fusion_attr = "fusion" if context.get_auto_parallel_context("parallel_mode") \
in ["data_parallel", "hybrid_parallel"] else "origin_fusion"
for step in range(self.calc_times):
current_group = self.device_number * (2 ** step)
sr_target = self.rank
@@ -208,13 +210,13 @@
for shape, dtype, name in left_delta_weights:
send_tag = self._hash(step, sr_target, weights_index)
send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
send.add_prim_attr("origin_fusion", fusion_id)
send.add_prim_attr(fusion_attr, fusion_id)
send.add_prim_attr("opposite_rank", dest_target)
send.add_prim_attr("target_param", name)
recv_tag = self._hash(step, dest_target, weights_index)
recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
group="hccl_world_group")
recv.add_prim_attr("origin_fusion", fusion_id)
recv.add_prim_attr(fusion_attr, fusion_id)
recv.add_prim_attr("opposite_rank", dest_target)
recv.add_prim_attr("target_param", name)
send_left.append(send)
@@ -223,19 +225,18 @@
for shape, dtype, name in right_delta_weights:
send_tag = self._hash(step, sr_target, weights_index)
send = Send(sr_tag=send_tag, dest_rank=dest_target, group="hccl_world_group")
send.add_prim_attr("origin_fusion", fusion_id + 1)
send.add_prim_attr(fusion_attr, fusion_id + 1)
send.add_prim_attr("opposite_rank", dest_target)
send.add_prim_attr("target_param", name)
recv_tag = self._hash(step, dest_target, weights_index)
recv = Receive(sr_tag=recv_tag, src_rank=dest_target, shape=shape, dtype=dtype,
group="hccl_world_group")
recv.add_prim_attr("origin_fusion", fusion_id + 1)
recv.add_prim_attr(fusion_attr, fusion_id + 1)
recv.add_prim_attr("opposite_rank", dest_target)
recv.add_prim_attr("target_param", name)
send_right.append(send)
recv_right.append(recv)
weights_index += 1
if self.send_node and self.send_node[-1]:
self.send_list_forward.append(send_left)
self.send_list_rollback.append(send_right)
@@ -250,17 +251,21 @@
last_delta_weights = left_delta_weights
param_allreduce_list = []
neighbor_ids = []
rank_ids = []
for index in range(2 ** (step + 1)):
node_rank = self.rank // self.device_number
double_d = 2 ** (step + 1)
neighbor_id = (node_rank // double_d * double_d + index) * self.device_number + \
self.rank % self.device_number
neighbor_ids.append(str(neighbor_id))
rank_ids.append(neighbor_id)
group_name = "-".join(neighbor_ids)
if context.get_auto_parallel_context("parallel_mode") in ["data_parallel", "hybrid_parallel"]:
create_group(group_name, rank_ids)
for parameter in self.parameter_tuple:
allreduce = P.AllReduce("sum", group_name)
allreduce.add_prim_attr("target_param", "adasum_delta_weight." + parameter.name)
allreduce.add_prim_attr("origin_fusion", fusion_id + 2)
allreduce.add_prim_attr(fusion_attr, fusion_id + 2)
allreduce.add_prim_attr("step", step)
param_allreduce_list.append(allreduce)
self.allreduce_list.append(param_allreduce_list)
@@ -358,15 +363,18 @@ def _clone_weight_process(scale, weight):
return scale_mul(weight, scale)
def _parallel_check():
if context.get_auto_parallel_context("parallel_mode") not in ["semi_auto_parallel",
"auto_parallel", "data_parallel"]:
raise RuntimeError("Stand alone and hybrid parallel mode is not supported to apply adasum.")
if context.get_auto_parallel_context("parallel_mode") == "data_parallel":
logger.warning("For data parallel mode, it is recommended to using mindspore.boost to enable adasum.")
"""Parallel infos checking"""
if context.get_auto_parallel_context("parallel_mode") == "stand_alone":
raise RuntimeError("Stand alone mode is not supported to apply adasum.")
if context.get_auto_parallel_context("parallel_mode") in ["data_parallel", "hybrid_parallel"]:
logger.warning("For data parallel mode or hybrid parallel mode, "
"it is recommended to using mindspore.boost to enable adasum.")
if context.get_auto_parallel_context("enable_parallel_optimizer"):
raise RuntimeError("Currently, the optimizer shard is not supported with applying adasum.")
if context.get_auto_parallel_context("pipeline_stages") > 1:
raise RuntimeError("Currently, the pipeline parallel is not supported with applying adasum.")
if _get_stage_device_num() < 16:
raise RuntimeError("The device_num should be at least 16 when applying adasum.")
class AdaSumByGradWrapCell(Cell):
r"""

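When parallel_mode is data_parallel or hybrid_parallel, the hunks above use the plain "fusion" attribute instead of "origin_fusion" and build each AllReduce group explicitly with create_group, presumably because those modes do not go through the auto-parallel pass that rewrites "origin_fusion". A minimal sketch of the neighbor-group arithmetic (Python, illustration only; device_number is assumed to be the per-server device count used by _AdaSum):

# Sketch of the neighbor-group arithmetic used in _AdaSum above.
def adasum_group_ranks(rank, device_number, step):
    node_rank = rank // device_number
    double_d = 2 ** (step + 1)
    return [(node_rank // double_d * double_d + index) * device_number + rank % device_number
            for index in range(double_d)]

# With 8 devices per server, rank 3 pairs with rank 11 at step 0 and with ranks
# 11, 19 and 27 at step 1; in data_parallel/hybrid_parallel mode the commit now
# creates each group as create_group("-".join(str(r) for r in ids), ids).
print(adasum_group_ranks(3, 8, 0))  # [3, 11]
print(adasum_group_ranks(3, 8, 1))  # [3, 11, 19, 27]
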
View File

@@ -356,7 +356,18 @@ class TrainOneStepCell(Cell):
if self.reducer_flag:
self.mean = _get_gradients_mean()
self.degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
if isinstance(self.optimizer, (nn.AdaSumByGradWrapCell, nn.AdaSumByDeltaWeightWrapCell)):
from mindspore.communication.management import get_group_size, create_group, get_rank
group_number = get_group_size() // 8
self.degree = int(self.degree / group_number)
group_list = [list(range(x * self.degree, (x + 1) * self.degree)) for x in range(group_number)]
current_index = get_rank() // 8
server_group_name = "allreduce_" + str(current_index)
create_group(server_group_name, group_list[current_index])
self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree,
group=server_group_name)
else:
self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
def construct(self, *inputs):
loss = self.network(*inputs)

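When the optimizer is AdaSumByGradWrapCell or AdaSumByDeltaWeightWrapCell, the hunk above splits the ranks into per-server groups of 8 devices and points DistributedGradReducer at the local group, so each rank's gradients are averaged only within its own server group. A minimal sketch of that partition (Python, illustration only; 8 is the per-server device count hard-coded by the commit, and _get_device_num() is assumed to equal get_group_size()):

# Sketch of the per-server group partition added to TrainOneStepCell above.
def server_group(rank, world_size, devices_per_server=8):
    group_number = world_size // devices_per_server
    degree = world_size // group_number           # devices averaged per reducer
    group_list = [list(range(x * degree, (x + 1) * degree)) for x in range(group_number)]
    current_index = rank // devices_per_server
    return "allreduce_" + str(current_index), group_list[current_index], degree

# With 32 devices, rank 21 lands in group "allreduce_2" covering ranks 16..23,
# and the reducer averages with degree 8 inside that group.
print(server_group(21, 32))  # ('allreduce_2', [16, 17, 18, 19, 20, 21, 22, 23], 8)
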
View File

@@ -393,7 +393,7 @@ class DistributedGradReducer(Cell):
fusion_type)
if not param_fusion:
self.split_fusion = False
self.allreduce = AllReduce().add_prim_attr('fusion', fusion_type)
self.allreduce = AllReduce('sum', group).add_prim_attr('fusion', fusion_type)
self.allgather = AllGather(group)
ps_filter = lambda x: x.is_param_ps
self.ps_parameters = tuple(ps_filter(x) for x in parameters)

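The one-line change above matters for those server groups: on the unfused path the reducer previously built AllReduce() on the default world group, silently ignoring its group argument, and it now passes both the reduce op and the requested group. A usage sketch under the assumption that communication is initialized and a group named "allreduce_0" covering 8 ranks was created beforehand with create_group:

# Sketch only: assumes init() has been called and "allreduce_0" already exists.
from mindspore.nn import DistributedGradReducer

def build_server_reducer(parameters, degree=8, group="allreduce_0"):
    # Gradients passed through this reducer are summed and averaged inside
    # `group` even on the unfused AllReduce path fixed above.
    return DistributedGradReducer(parameters, mean=True, degree=degree, group=group)
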
View File

@@ -0,0 +1,134 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum, AdaSumByDeltaWeightWrapCell, AdaSumByGradWrapCell
from mindspore.ops import operations as P
class Net(Cell):
def __init__(self, strategy1=None, strategy2=None, strategy3=None):
super().__init__()
self.mul = P.Mul().shard(strategy1)
self.matmul = P.MatMul().shard(strategy2)
self.gather = P.Gather().shard(strategy3)
self.reduce_sum = P.ReduceSum()
self.mul_weight = Parameter(Tensor(np.ones([64, 32]), dtype=ms.float32), "w1")
self.matmul_weight = Parameter(Tensor(np.ones([32, 32]), dtype=ms.float32), "w2")
self.embedding_table = Parameter(Tensor(np.ones([64, 32]), dtype=ms.float32), "embedding_table")
def construct(self, x, b):
out = self.gather(self.embedding_table, x, 0)
out = self.matmul(out, self.matmul_weight)
out = self.mul(out, self.mul_weight)
out = out + b
return self.reduce_sum(out)
_x = Tensor(np.ones([64]), dtype=ms.int32)
_b = Tensor(np.ones([64, 32]), dtype=ms.float32)
def compile_net(net, by_grad=True):
if by_grad:
optimizer = AdaSumByGradWrapCell(Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9))
else:
optimizer = AdaSumByDeltaWeightWrapCell(Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9))
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
train_net.set_train()
_cell_graph_executor.compile(train_net, _x, _b)
context.reset_auto_parallel_context()
def test_auto_parallel_adasum1():
"""
Feature: adasum in auto parallel.
Description: verify adasum by mul/matmul/gather, rank0, dp, mp, not_full_dp
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
mul_strategy1 = ((8, 4), (8, 4))
matmul_strategy2 = ((8, 1), (1, 1))
gather_strategy3 = ((1, 1), (32,))
net = Net(mul_strategy1, matmul_strategy2, gather_strategy3)
compile_net(net)
def test_auto_parallel_adasum2():
"""
Feature: adasum in auto parallel.
Description: verify adasum by mul/matmul/gather, rank0, dp, mp, not_full_dp
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
mul_strategy1 = ((8, 4), (8, 4))
matmul_strategy2 = ((8, 1), (1, 1))
gather_strategy3 = ((1, 1), (32,))
net = Net(mul_strategy1, matmul_strategy2, gather_strategy3)
compile_net(net, by_grad=False)
def test_auto_parallel_adasum3():
"""
Feature: adasum in auto parallel.
Description: verify adasum by mul/matmul/gather, rank0, mix_dp_mp, mp
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
mul_strategy1 = ((8, 4), (8, 4))
matmul_strategy2 = ((8, 4), (4, 1))
gather_strategy3 = ((32, 1), (1,))
net = Net(mul_strategy1, matmul_strategy2, gather_strategy3)
compile_net(net)
def test_auto_parallel_adasum4():
"""
Feature: adasum in auto parallel.
Description: verify adasum by mul/matmul/gather, rank0, mix_dp_mp, mp
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
mul_strategy1 = ((8, 4), (8, 4))
matmul_strategy2 = ((8, 4), (4, 1))
gather_strategy3 = ((32, 1), (1,))
net = Net(mul_strategy1, matmul_strategy2, gather_strategy3)
compile_net(net, by_grad=False)
def test_auto_parallel_adasum5():
"""
Feature: adasum in auto parallel.
Description: verify adasum by mul/matmul/gather, rank16, dp, mp, not_full_dp
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=16)
mul_strategy1 = ((8, 4), (8, 4))
matmul_strategy2 = ((8, 1), (1, 1))
gather_strategy3 = ((1, 1), (32,))
net = Net(mul_strategy1, matmul_strategy2, gather_strategy3)
compile_net(net)
def test_auto_parallel_adasum6():
"""
Feature: adasum in auto parallel.
Description: verify adasum by mul/matmul/gather, rank16, dp, mp, not_full_dp
Expectation: compile done without error.
"""
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=16)
mul_strategy1 = ((8, 4), (8, 4))
matmul_strategy2 = ((8, 1), (1, 1))
gather_strategy3 = ((1, 1), (32,))
net = Net(mul_strategy1, matmul_strategy2, gather_strategy3)
compile_net(net, by_grad=False)