!18905 pangu_pipeline

Merge pull request !18905 from yao_yf/pangu_pipeline
i-robot 2021-06-30 07:22:02 +00:00 committed by Gitee
commit abc77e13e9
9 changed files with 573 additions and 41 deletions


@@ -225,8 +225,9 @@ bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2) {
if (rank_tag2 == nullptr) {
rank_tag2 = prim2->GetAttr(DEST_RANK);
}
MS_EXCEPTION_IF_NULL(rank_tag1);
MS_EXCEPTION_IF_NULL(rank_tag2);
if (!rank_tag1 || !rank_tag2) {
return false;
}
auto rank1_value = GetValue<int64_t>(rank_tag1);
auto rank2_value = GetValue<int64_t>(rank_tag2);
if (rank1_value == rank2_value) {


@@ -136,7 +136,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
// define the parse constant
const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1;
const char CUSTOM_BPROP_NAME[] = "bprop";
const char STAGE_NAME[] = "pipeline_stage";
const char STAGE_NAME[] = "_pipeline_stage";
// define the Namespace name
const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace


@@ -154,6 +154,7 @@ class Parameter(Tensor_):
self._cast_type = None
self._unique = False
self.is_in_parallel = _is_in_parallel_mode()
self._pipeline_stage_list = []
if isinstance(default_input, (Tensor_, Tensor)):
Tensor_.__init__(self, default_input.dtype, default_input.shape)
elif isinstance(default_input, int):
@@ -452,6 +453,9 @@ class Parameter(Tensor_):
new_param.param_info = self.param_info
return new_param
def add_pipeline_stage(self, stage):
self._pipeline_stage_list.append(stage)
def set_data(self, data, slice_shape=False):
"""
Set Parameter's data.


@@ -230,6 +230,18 @@ class Cell(Cell_):
raise TypeError("'parallel_parameter_name_list' must be list type.")
self._parallel_parameter_name_list = value
@property
def pipeline_stage(self):
return self._pipeline_stage
@pipeline_stage.setter
def pipeline_stage(self, value):
if not isinstance(value, int):
raise TypeError("'pipeline_stage' must be int type.")
self._pipeline_stage = value
for item in self.trainable_params():
item.add_pipeline_stage(value)
@property
def parallel_parameter_merge_net_dict(self):
return self._parallel_parameter_merge_net_dict
@@ -1297,6 +1309,41 @@ class Cell(Cell_):
for cell in self.cells():
cell.recompute(mode, True)
def infer_param_pipeline_stage(self):
"""
Infer pipeline stages of all parameters in the cell.
Notes:
- If a parameter does not belong to any cell which has been set pipeline_stage,
the parameter should use add_pipeline_stage to add its pipeline_stage information.
- If a parameter P has been used by two operators in different stages "stageA" and "stageB",
the parameter P should use P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB)
to add its stage information before calling infer_param_pipeline_stage.
Returns:
The parameters that belong to the current stage in pipeline parallel.
Raises:
RuntimeError: If there is a parameter that does not belong to any stage.
"""
from mindspore.communication import get_group_size, get_rank
stage_num = context.get_auto_parallel_context("pipeline_stages")
device_num = get_group_size()
rank_id = get_rank()
per_stage_devices = device_num // stage_num
current_stage = rank_id // per_stage_devices
params = []
for param in self.trainable_params():
if not param._pipeline_stage_list:
raise RuntimeError("The parameter {} does not belong to any stage, "
"please check whether the cell where the param locates"
" has been set pipeline_stage. "
"Otherwise, the parameter should use add_pipeline_stage "
"to add its stage information".format(param.name))
if current_stage in param._pipeline_stage_list:
params.append(param)
return params
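As an aside, here is a minimal, framework-free sketch of the stage bookkeeping added above: every parameter records each pipeline stage that uses it, and infer_param_pipeline_stage keeps only the parameters whose recorded stages include the stage owned by the current rank. MockParam and the concrete device numbers are illustrative, not MindSpore API.

class MockParam:
    def __init__(self, name):
        self.name = name
        self._pipeline_stage_list = []
    def add_pipeline_stage(self, stage):
        self._pipeline_stage_list.append(stage)

def infer_param_pipeline_stage(params, rank_id, device_num, stage_num):
    # mirrors the rank -> stage arithmetic used by Cell.infer_param_pipeline_stage
    per_stage_devices = device_num // stage_num
    current_stage = rank_id // per_stage_devices
    selected = []
    for param in params:
        if not param._pipeline_stage_list:
            raise RuntimeError("parameter {} does not belong to any stage".format(param.name))
        if current_stage in param._pipeline_stage_list:
            selected.append(param)
    return selected

emb = MockParam("embedding_table")
emb.add_pipeline_stage(0)    # looked up by the first block
emb.add_pipeline_stage(3)    # reused by the head on the last stage
head_weight = MockParam("head.weight")
head_weight.add_pipeline_stage(3)
# With 32 devices and 4 stages, rank 30 belongs to stage 3 and receives both parameters.
print([p.name for p in infer_param_pipeline_stage([emb, head_weight], 30, 32, 4)])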
class GraphKernel(Cell):
"""


@@ -302,6 +302,32 @@ class EmbeddingLookup(nn.Cell):
return output, self.embedding_table
class EmbeddingLookupPipeline(nn.Cell):
"""
The embedding lookup table for the vocabulary
Args:
config(PanguAlphaConfig): the config of network
Inputs:
input_ids: the tokenized inputs with datatype int32
table: the embedding table of the vocabulary
Returns:
output: Tensor, the embedding vector for the input with shape (batch_size, seq_length, embedding_size)
"""
def __init__(self, config):
super(EmbeddingLookupPipeline, self).__init__()
self.vocab_size = config.vocab_size
self.embedding_size = config.embedding_size
if config.word_emb_dp:
self.gather = P.GatherV2().shard(((1, 1), (config.dp, 1)))
else:
self.gather = P.GatherV2().shard(((config.mp, 1), (1, 1)))
self.gather.add_prim_attr("parameter_start", 0)
self.shape = (-1, config.seq_length, config.embedding_size)
def construct(self, input_ids, table):
output = self.gather(table, input_ids, 0)
return output
class Attention(nn.Cell):
"""
Self-Attention module for each layer
@@ -593,6 +619,44 @@ class Block(nn.Cell):
output = self.add(x, mlp_logit)
return output, layer_present
class PanguAlpha_EmbeddingPipeLine(nn.Cell):
"""
The word and position embedding layer of PanguAlpha used in the pipeline-parallel network
"""
def __init__(self, config):
super(PanguAlpha_EmbeddingPipeLine, self).__init__()
self.word_embedding = EmbeddingLookupPipeline(config)
self.position_embedding = nn.Embedding(config.seq_length,
config.embedding_size,
embedding_table=Normal(0.02))
self.position_embedding.gather.shard(((1, 1), (config.dp,)))
self.position_embedding.expand.shard(((config.dp, 1),))
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
self.dropout = Dropout(1 - config.dropout_rate)
self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
def construct(self, input_ids, table, input_position):
input_embedding = self.word_embedding(input_ids, table)
position_embedding = self.position_embedding(input_position)
hidden_states = self.add(input_embedding, position_embedding)
hidden_states = self.dropout(hidden_states)
hidden_states = P.Cast()(hidden_states, mstype.float16)
return hidden_states
class PanguAlpha_Mask(nn.Cell):
"""
Expand the attention mask to the shape expected by the attention layers
"""
def __init__(self, config):
super(PanguAlpha_Mask, self).__init__()
self.get_attention_mask = AttentionMask(config)
self.dtype = config.compute_dtype
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
def construct(self, input_mask, attention_mask):
attention_mask = self.expand_dims(attention_mask, 1)
return attention_mask
class QueryLayerAttention(Attention):
r"""
@@ -828,6 +892,74 @@ class PanguAlpha_Model(nn.Cell):
present_layer = present_layer + (present,)
return output_state, present_layer, embedding_table
class PanguAlpha_ModelPipeline(nn.Cell):
"""
The backbone of the PanguAlpha network with pipeline parallel
Args:
config(PanguAlphaConfig): the config of network
Inputs:
input_ids: the tokenized inputs with datatype int32
input_mask: the mask indicating whether each position is a valid input
table: the embedding table of the vocabulary
input_position: the position index of each token
attention_mask: the attention mask for the inputs
layer_past: the previous feature map
Returns:
output_state: Tensor, the output logit of backbone
present_layer: Tensor, the current feature map
"""
def __init__(self, config):
super(PanguAlpha_ModelPipeline, self).__init__()
self.pangu_alpha_embedding = PanguAlpha_EmbeddingPipeLine(config).set_comm_fusion(1)
self.pangu_alpha_embedding.pipeline_stage = 0
self.pangu_alpha_mask = PanguAlpha_Mask(config)
self.blocks = nn.CellList()
self.top_query_embedding = nn.Embedding(config.seq_length, config.embedding_size,
embedding_table=TruncatedNormal(0.02))
self.top_query_embedding.gather.shard(((1, 1), (config.dp,)))
self.top_query_embedding.expand.shard(((config.dp, 1),))
for i in range(config.num_layers):
if i == config.num_layers - 1:
self.top_query_embedding.set_comm_fusion(2)
self.top_query_embedding.pipeline_stage = i * config.stage_num // config.num_layers
per_block = QueryLayer(config).set_comm_fusion(2)
else:
per_block = Block(config, i + 1).set_comm_fusion(2)
per_block.pipeline_stage = i * config.stage_num // config.num_layers
per_block.recompute()
self.blocks.append(per_block)
if config.self_layernorm:
self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
else:
self.layernorm = nn.LayerNorm(
(config.embedding_size,)).to_float(mstype.float32)
self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
self.layernorm.set_comm_fusion(2)
self.layernorm.pipeline_stage = config.stage_num - 1
self.use_past = config.use_past
self.past = tuple([None] * config.num_layers)
self.dtype = config.compute_dtype
self.num_layers = config.num_layers
def construct(self, input_ids, input_mask, table, input_position, attention_mask, layer_past=None):
"""PanguAlpha model"""
if not self.use_past:
layer_past = self.past
hidden_states = self.pangu_alpha_embedding(input_ids, table, input_position)
attention_mask = self.pangu_alpha_mask(input_mask, attention_mask)
present_layer = ()
for i in range(self.num_layers-1):
hidden_states, present = self.blocks[i](hidden_states,
attention_mask, layer_past)
present_layer = present_layer + (present,)
top_query_hidden_states = self.top_query_embedding(input_position)
hidden_states, present = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states,
attention_mask, layer_past)
present_layer = present_layer + (present,)
output_state = self.layernorm(hidden_states)
output_state = F.cast(output_state, self.dtype)
return output_state, present_layer
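A quick check of the layer-to-stage mapping used above, i * config.stage_num // config.num_layers, with assumed values num_layers = 8 and stage_num = 4 (illustrative, not a shipped configuration):

num_layers, stage_num = 8, 4
print([i * stage_num // num_layers for i in range(num_layers)])
# -> [0, 0, 1, 1, 2, 2, 3, 3]: consecutive blocks are packed evenly onto the stages,
#    and the top-query layer (the last block) always lands on the final stage.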
class PanguAlpha_Head(nn.Cell):
"""
@@ -883,6 +1015,35 @@ class PanguAlpha(nn.Cell):
logits = self.head(output_states, embedding_table)
return logits
class PanguAlphaPipeline(nn.Cell):
"""
The PanguAlpha network consisting of two parts: the backbone and the head
Args:
config(PanguAlphaConfig): the config of network
Inputs:
input_ids: the tokenized inputs
input_mask: the mask indicating whether each position is a valid input
input_position: the position index of each token
attention_mask: the attention mask for the inputs
past: the previous feature map
Returns:
logits: Tensor, the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size)
"""
def __init__(self, config):
super(PanguAlphaPipeline, self).__init__()
self.backbone = PanguAlpha_ModelPipeline(config)
self.head = PanguAlpha_Head(config)
self.head.pipeline_stage = config.stage_num - 1
self.vocab_size = config.vocab_size
self.embedding_size = config.embedding_size
self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]),
name="embedding_table")
self.embedding_table.add_pipeline_stage(self.backbone.blocks[0].pipeline_stage)
self.embedding_table.add_pipeline_stage(self.head.pipeline_stage)
def construct(self, input_ids, input_mask, input_position, attention_mask, past=None):
output_states, _ = self.backbone(input_ids, input_mask, self.embedding_table,
input_position, attention_mask, past)
logits = self.head(output_states, self.embedding_table)
return logits
class CrossEntropyLoss(nn.Cell):
"""
@@ -1010,6 +1171,38 @@ class PanguAlphaWithLoss(nn.Cell):
output = self.loss(logits, labels, input_mask)
return output
class PanguAlphaWithLossPipeline(nn.Cell):
"""
PanguAlpha training loss for the pipeline-parallel network
Args:
config: the config of network
network: backbone network of PanguAlpha
loss: loss function, e.g., crossentropy
eos_token: the end_of_sentence token
Inputs:
input_ids: the tokenized inputs
input_position: the position index of each token
attention_mask: the attention mask for the inputs
Returns:
output: Tensor, the loss of the network
"""
def __init__(self, config, network, loss, eos_token=6):
super(PanguAlphaWithLossPipeline, self).__init__(auto_prefix=False)
self.network = network
self.loss = loss
self.eos_token = eos_token
self.slice = P.StridedSlice().shard(((config.dp, 1),))
self.not_equal = P.NotEqual().shard(((config.dp, 1), ()))
self.batch_size = config.batch_size
self.len = config.seq_length
self.micro_batch_step = config.micro_size
def construct(self, input_ids, input_position, attention_mask):
tokens = self.slice(input_ids, (0, 0), (self.batch_size // self.micro_batch_step, -1), (1, 1))
input_mask = F.cast(self.not_equal(tokens, self.eos_token), mstype.float32)
logits = self.network(tokens, input_mask, input_position, attention_mask)
labels = self.slice(input_ids, (0, 1), (self.batch_size // self.micro_batch_step,
self.len + 1), (1, 1))
output = self.loss(logits, labels, input_mask)
return output
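The two StridedSlice calls above shift the labels by one position relative to the tokens; the NumPy sketch below replays that indexing with made-up sizes (batch_size = 8, micro_size = 4, seq_length = 6), assuming the dataset delivers seq_length + 1 token ids per sample, which is what the label slice up to self.len + 1 implies:

import numpy as np

batch_size, micro_size, seq_length = 8, 4, 6
ids = np.arange(batch_size * (seq_length + 1)).reshape(batch_size, seq_length + 1)
per_micro = batch_size // micro_size
tokens = ids[:per_micro, :-1]               # slice (0, 0) .. (batch_size // micro_size, -1)
labels = ids[:per_micro, 1:seq_length + 1]  # slice (0, 1) .. (batch_size // micro_size, seq_length + 1)
assert tokens.shape == labels.shape == (per_micro, seq_length)
# labels[i, j] is the token that follows tokens[i, j] in the same sequence.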
class EvalNet(nn.Cell):
"""


@@ -89,20 +89,27 @@ def set_parse(args_opt):
args_opt.embedding_size = 16384
args_opt.num_layers = 64
args_opt.num_heads = 128
args_opt.per_batch_size = 1
args_opt.word_emb_dp = 0
if args_opt.run_type == "train":
args_opt.start_lr = 6e-5
args_opt.end_lr = 6e-6
args_opt.optimizer_shard = 0
args_opt.stage_num = 16
args_opt.micro_size = 32
args_opt.op_level_model_parallel_num = 16
if args_opt.optimizer_shard == 1:
args_opt.op_level_model_parallel_num = 8
elif args_opt.run_type == "predict":
args_opt.stage_num = 4
args_opt.micro_size = 1
args_opt.op_level_model_parallel_num = 16
if args_opt.optimizer_shard == 1:
args_opt.op_level_model_parallel_num = 8
elif args_opt.mode == "13B":
args_opt.embedding_size = 5120
args_opt.num_layers = 40
args_opt.num_heads = 40
args_opt.word_emb_dp = 1
args_opt.op_level_model_parallel_num = 8
if args_opt.run_type == "train":
args_opt.start_lr = 5e-5


@@ -20,8 +20,11 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops.operations.comm_ops import _VirtualDataset
from mindspore.nn.wrap.loss_scale import TrainOneStepWithLossScaleCell
from mindspore import context, Parameter
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
from src.utils import ClipByGlobalNorm
GRADIENT_CLIP_TYPE = 1
@@ -65,16 +68,14 @@ reciprocal = P.Reciprocal()
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
class VirtualDatasetOneInputCell(nn.Cell):
def __init__(self, backbone):
super(VirtualDatasetOneInputCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._virtual_dataset = _VirtualDataset()
def construct(self, *data):
data_ = self._virtual_dataset(*data)
return self._backbone(*data_)
@grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_grad_scale_pipeline(scale, grad, accu_grad):
accu_grad = F.depend(accu_grad, grad)
new_grad = accu_grad * reciprocal(scale)
accu_grad = F.depend(accu_grad, new_grad)
zeros = F.tensor_mul(accu_grad, 0.0)
_ = F.assign(accu_grad, zeros)
return new_grad
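A framework-free sketch of what tensor_grad_scale_pipeline computes for one (scale, grad, accu_grad) triple: the accumulated gradient is divided by the loss scale and returned as the effective gradient, and the accumulation buffer is then reset to zero for the next accumulation window (grad itself only enters through F.depend, which orders the operations). Plain floats stand in for Tensors here; this is an illustration, not the MindSpore implementation.

def tensor_grad_scale_pipeline_sketch(scale, accu_grad_buffer):
    new_grad = accu_grad_buffer[0] * (1.0 / scale)  # accu_grad * reciprocal(scale)
    accu_grad_buffer[0] = 0.0                       # F.assign(accu_grad, zeros)
    return new_grad

buffer = [12.0]                                     # gradient accumulated over the micro batches
print(tensor_grad_scale_pipeline_sketch(4.0, buffer), buffer)   # 3.0 [0.0]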
class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
"""
@@ -102,7 +103,7 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
self.optimizer = optimizer
self.default_lr = Tensor([0.0], dtype=mstype.float32)
self.enable_global_norm = enable_global_norm
self.clip = ClipByGlobalNorm(self.weights)
self.clip = ClipByGlobalNorm(self.weights, config)
self.cast = P.Cast()
def construct(self, input_ids, input_position=None, attention_mask=None, layer_past=None, sens=None):
@@ -142,3 +143,111 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
else:
succ = self.optimizer(grads)
return F.depend(loss, succ), cond, scaling_sens
class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
"""
Encapsulation class of PanguAlpha network training with pipeline parallel.
Appends an optimizer to the training network; after that, the construct
function can be called to create the backward graph.
Args:
network (Cell): The training network. Note that the loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
config (PanguAlphaConfig): The config of the network.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
enable_global_norm (bool): Whether to clip gradients by the global norm. Default: True.
"""
def __init__(self, network, optimizer, config, scale_update_cell=None, enable_global_norm=True):
super(PanguAlphaTrainPipelineWithLossScaleCell, self).__init__(auto_prefix=False)
self.config = config
self.network = network
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters
self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
self.optimizer = optimizer
self.enable_global_norm = enable_global_norm
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.reducer_flag = False
self.allreduce = P.AllReduce()
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.grad_reducer = F.identity
self.degree = 1
if self.reducer_flag:
self.degree = get_group_size()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus().add_prim_attr("_side_effect_flag", False)
self.get_status = P.NPUGetFloatStatus().add_prim_attr("_side_effect_flag", False)
self.clear_before_grad = P.NPUClearFloatStatus().add_prim_attr("_side_effect_flag", False)
self.reduce_sum = P.ReduceSum(keep_dims=False)
#self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.reshape = P.Reshape()
#self.control = P.ControlDepend(1)
#self.clip_norm = Tensor(1000.0, mstype.float32)
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
name="loss_scale")
self.clip = ClipByGlobalNorm(self.weights, self.config)
self.micro_size = config.micro_size
@C.add_flags(has_effect=True)
def construct(self,
input_ids,
input_position,
attention_mask,
past=None,
sens=None):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(input_ids, input_position, attention_mask)
if sens is None:
scaling_sens = self.loss_scale
scaling_sens = self.reshape(scaling_sens, (1,))
else:
scaling_sens = sens
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
status_clear = self.clear_before_grad(init)
#clear_depend = self.control(status_clear, self.weights)
grads = self.grad(self.network, weights)(input_ids,
input_position,
attention_mask,
self.cast(scaling_sens / self.micro_size,
mstype.float32))
init = F.depend(init, grads)
get_status = self.get_status(init)
init = F.depend(init, get_status)
flag_sum = self.reduce_sum(init, (0,))
loss = F.depend(loss, status_clear)
# apply grad reducer on grads
accu_grads = self.grad_reducer(self.accu_grads)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
if self.enable_global_norm:
grads, _ = self.clip(grads)
else:
grads = self.hyper_map(
F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
grads)
if self.is_distributed:
# sum overflow flag over devices
flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce)
else:
cond = self.less_equal(self.base, flag_sum)
overflow = cond
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, overflow, scaling_sens)
return F.depend(ret, succ)


@@ -26,8 +26,8 @@ from mindspore.ops import functional as F
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR, CosineDecayLR
from mindspore.parallel._utils import _get_global_rank
from mindspore.communication.management import get_group_size
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_rank, get_group_size, create_group
from mindspore.nn import AdamWeightDecay
from mindspore.common import Parameter, ParameterTuple
from mindspore.common.initializer import initializer
@@ -71,14 +71,12 @@ class FP32StateAdamWeightDecay(AdamWeightDecay):
get_square_sum = C.MultitypeFuncGraph("get_square_sum")
@get_square_sum.register("Tensor", "Tensor")
@get_square_sum.register("Tensor", "Number")
def _get_square_sum(grad, value):
norm = P.ReduceSum(False)(F.square(grad) / value, ())
norm = P.ReduceSum(False)(F.square(grad), ()) / value
norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
return norm
apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@@ -87,6 +85,41 @@ def _apply_global_norm(clip_norm, global_norm, grad):
grad = grad * clip_norm / global_norm
return grad
def _get_model_parallel_group(mp):
"""
Calculate the communication group of the model-parallel dimension within one pipeline stage
"""
rank = get_rank()
stage_nums = auto_parallel_context().get_pipeline_stages()
device_nums = get_group_size()
per_stage_device_nums = device_nums // stage_nums
stage_id = rank // per_stage_device_nums
local_stage_rank_id = rank % per_stage_device_nums
index = local_stage_rank_id // mp
group = range(0, mp)
rank_str_list = [str(x + index * mp + stage_id * per_stage_device_nums) for x in group]
rank_list_str = "-".join(rank_str_list)
rank_list = [x + index * mp + stage_id * per_stage_device_nums for x in group]
return rank_list, rank_list_str
def _get_pipeline_group():
"""
Calculate the communication group across all pipeline stages
"""
rank = get_rank()
stage_nums = auto_parallel_context().get_pipeline_stages()
device_nums = get_group_size()
per_stage_device_nums = device_nums // stage_nums
local_stage_rank_id = rank % per_stage_device_nums
group = range(0, stage_nums)
rank_list = [local_stage_rank_id + x * per_stage_device_nums for x in group]
rank_str_list = [str(local_stage_rank_id + x * per_stage_device_nums) for x in group]
rank_list_str = "-".join(rank_str_list)
return rank_list, rank_list_str
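These two helpers only do integer arithmetic on the rank id, so their output can be checked without any devices; the sketch below reruns the same formulas for an assumed layout of 16 devices, 2 pipeline stages and mp = 4 (all of these numbers are illustrative):

def model_parallel_group(rank, mp, stage_nums, device_nums):
    per_stage = device_nums // stage_nums
    stage_id = rank // per_stage
    local = rank % per_stage
    index = local // mp
    return [x + index * mp + stage_id * per_stage for x in range(mp)]

def pipeline_group(rank, stage_nums, device_nums):
    per_stage = device_nums // stage_nums
    local = rank % per_stage
    return [local + x * per_stage for x in range(stage_nums)]

print(model_parallel_group(rank=13, mp=4, stage_nums=2, device_nums=16))  # [12, 13, 14, 15]
print(pipeline_group(rank=13, stage_nums=2, device_nums=16))              # [5, 13]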
class GlobalNorm(nn.Cell):
"""
@@ -107,9 +140,9 @@ class GlobalNorm(nn.Cell):
self.group_size = get_group_size()
for item in self.allreduce_filter:
if item:
self.values.append(Tensor([1.0], mstype.float32))
self.values.append(1.0)
else:
self.values.append(Tensor([self.group_size * 1.0], mstype.float32))
self.values.append(self.group_size * 1.0)
self.values = tuple(self.values)
def construct(self, grads):
@@ -119,6 +152,37 @@ class GlobalNorm(nn.Cell):
global_norms = F.sqrt(P.AllReduce()(F.addn(square_sum_dp)))
return global_norms
class GlobalNormPipline(nn.Cell):
"""
Calculate the global norm value of given tensors
"""
def __init__(self, params, config):
super(GlobalNormPipline, self).__init__()
self.norm = nn.Norm()
self.hyper_map = C.HyperMap()
self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name
and "position_embedding.embedding_table" not in x.name for x in params)
self.allreduce_group_size = ()
for item in self.allreduce_filter:
if item:
self.allreduce_group_size = self.allreduce_group_size + (1.0,)
else:
self.allreduce_group_size = self.allreduce_group_size + (config.mp * 1.0,)
self.length = len(params)
group_list, group_name = _get_model_parallel_group(config.mp)
create_group(group_name, group_list)
self.allreduce = P.AllReduce(group=group_name)
pipeline_group_list, pipeline_group_name = _get_pipeline_group()
create_group(pipeline_group_name, pipeline_group_list)
self.allreduce2 = P.AllReduce(group=pipeline_group_name)
def construct(self, grads):
square_sum = self.hyper_map(get_square_sum, grads, self.allreduce_group_size)
square_reduce_sum = F.addn(square_sum)
stage_square_reduce_sum = self.allreduce(square_reduce_sum)
global_square_reduce_sum = self.allreduce2(stage_square_reduce_sum)
global_norms = F.sqrt(global_square_reduce_sum)
return global_norms
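A serial sketch of the reduction order in GlobalNormPipline: each rank square-sums its local gradients (dividing by the allreduce_group_size factor, which is mp for gradients that are replicated inside the model-parallel group, presumably so they are not counted mp times), the first AllReduce sums over the stage's model-parallel group, the second AllReduce sums over the pipeline group, and the square root gives the global norm. Python sums stand in for the two AllReduce calls, and all numbers below are made up.

import math

def device_partial(grads_with_factor):
    # mirrors get_square_sum followed by F.addn on one rank
    return sum(sum(g * g for g in grad) / factor for grad, factor in grads_with_factor)

mp = 2
# one model-parallel group of two ranks per stage, two stages
stage0_mp_group = [device_partial([([3.0, 4.0], mp)]), device_partial([([3.0, 4.0], mp)])]
stage1_mp_group = [device_partial([([1.0, 2.0], 1.0)]), device_partial([([2.0, 1.0], 1.0)])]
per_stage = [sum(stage0_mp_group), sum(stage1_mp_group)]   # AllReduce over each stage's mp group
global_square_sum = sum(per_stage)                          # AllReduce over the pipeline group
print(math.sqrt(global_square_sum))                         # sqrt(25.0 + 10.0)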
class ClipByGlobalNorm(nn.Cell):
"""
@@ -126,10 +190,12 @@ class ClipByGlobalNorm(nn.Cell):
Clip grads by global norm
"""
def __init__(self, params, clip_norm=1.0):
def __init__(self, params, config, clip_norm=1.0):
super(ClipByGlobalNorm, self).__init__()
self.global_norm = GlobalNorm(params)
if config.stage_num > 1:
self.global_norm = GlobalNormPipline(params, config)
else:
self.global_norm = GlobalNorm(params)
self.clip_norm = Tensor([clip_norm], mstype.float32)
self.hyper_map = C.HyperMap()
@@ -140,14 +206,6 @@ class ClipByGlobalNorm(nn.Cell):
grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
return grads, global_norm_value
def _get_model_parallel_group(dp, mp):
rank = _get_global_rank()
group = range(0, mp)
index = rank // dp
return [x + index * mp for x in group]
class LearningRate(LearningRateSchedule):
"""
Warmup-decay learning rate for PanguAlpha network.
@@ -258,6 +316,10 @@ def add_training_params(opt):
type=int,
default=2000,
help="Warmup step, default is 2000.")
opt.add_argument("--decay_steps",
type=int,
default=200000,
help="Decay step, default is 200000.")
opt.add_argument("--optimizer",
type=str,
default="adam",
@@ -298,6 +360,11 @@ def add_training_params(opt):
type=int,
default=8,
help="The model parallel way. default 8")
opt.add_argument("--word_emb_dp",
type=int,
default=1,
choices=[0, 1],
help="Whether do data parallel in word embedding. default 1")
def get_args(inference=False):


@@ -29,10 +29,11 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.common.dtype as mstype
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell
from src.dataset import create_dataset
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss,\
PanguAlphaPipeline, PanguAlphaWithLossPipeline, CrossEntropyLoss
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell
from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay
from src.utils import download_data
@@ -132,14 +133,14 @@ def run_train(args_opt):
micro_size=args_opt.micro_size,
eod_reset=bool(args_opt.eod_reset),
param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
word_emb_dp=True)
word_emb_dp=bool(args_opt.word_emb_dp))
print("===config is: ", config, flush=True)
# Define network
pangu_alpha = PanguAlpha(config)
loss = CrossEntropyLoss(config)
pangu_alpha_with_loss = PanguAlphaWithLoss(config, pangu_alpha, loss)
pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)
pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss)
print("=====args_opt is: ", args_opt, flush=True)
@@ -189,8 +190,111 @@ def run_train(args_opt):
print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True)
model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size, dataset_sink_mode=True)
def run_train_pipeline(args_opt):
r"""
The main training process in pipeline-parallel mode.
"""
device_id = int(os.getenv("DEVICE_ID"))
context.set_context(save_graphs=False,
mode=context.GRAPH_MODE,
device_target="Ascend",
device_id=device_id)
context.set_context(variable_memory_max_size="31GB")
if args_opt.distribute == "true":
D.init()
device_num = D.get_group_size()
rank_id = D.get_rank()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
gradients_mean=False,
device_num=device_num,
full_batch=True,
loss_repeated_mean=True,
enable_parallel_optimizer=bool(args_opt.optimizer_shard),
pipeline_stages=args_opt.stage_num)
set_algo_parameters(elementwise_op_strategy_follow=True)
_set_multi_subgraphs()
else:
rank_id = int(os.getenv("RANK_ID"))
device_num = 1
model_parallel_num = args_opt.op_level_model_parallel_num
stage_device_num = int(device_num / args_opt.stage_num)
data_parallel_num = int(stage_device_num / model_parallel_num)
per_batch_size = args_opt.per_batch_size
batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
config = PANGUALPHAConfig(
data_parallel_num=data_parallel_num,
model_parallel_num=model_parallel_num,
batch_size=batch_size,
seq_length=args_opt.seq_length,
vocab_size=args_opt.vocab_size,
embedding_size=args_opt.embedding_size,
num_layers=args_opt.num_layers,
num_heads=args_opt.num_heads,
expand_ratio=4,
post_layernorm_residual=False,
dropout_rate=0.1,
compute_dtype=mstype.float16,
use_past=False,
self_layernorm=True,
stage_num=args_opt.stage_num,
micro_size=args_opt.micro_size,
word_emb_dp=bool(args_opt.word_emb_dp))
print("===config is: ", config, flush=True)
pangu_alpha = PanguAlphaPipeline(config)
loss = CrossEntropyLoss(config)
pangu_alpha_with_loss = PipelineCell(PanguAlphaWithLossPipeline(config, pangu_alpha, loss), config.micro_size)
pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss)
print("=====args_opt is: ", args_opt, flush=True)
lr = LearningRate(learning_rate=args_opt.start_lr,
end_learning_rate=args_opt.end_lr,
warmup_steps=args_opt.warmup_step,
decay_steps=args_opt.decay_steps)
params = pangu_alpha.infer_param_pipeline_stage()
decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
decay_params = list(filter(decay_filter, params))
other_params = list(filter(lambda x: not decay_filter(x), params))
group_params = [{
'params': decay_params,
'weight_decay': 1e-1
}, {
'params': other_params,
'weight_decay': 0.0
}, {
'order_params': params
}]
if args_opt.optimizer == "lamb":
optimizer = nn.Lamb(group_params, learning_rate=lr)
else:
optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
ds = create_dataset(config.batch_size, data_path=args_opt.data_url, eod_reset=True,
data_start_index=0, full_batch=True)
epoch_num = args_opt.epoch_size
step_per_epoch = ds.get_dataset_size()
callback_size = args_opt.sink_size
actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
callback = [
TimeMonitor(callback_size),
LossCallBack(callback_size, rank_id, config.stage_num)
]
loss_scale_value = math.pow(2, 32)
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value,
scale_factor=2,
scale_window=1000)
pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell)
model = Model(pangu_alpha_with_grads)
model.train(actual_epoch_num,
ds,
callbacks=callback,
sink_size=callback_size,
dataset_sink_mode=True)
if __name__ == "__main__":
opt = get_args()
set_parse(opt)
run_train(opt)
if opt.stage_num > 1:
run_train_pipeline(opt)
else:
run_train(opt)
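Finally, a worked example of the parallel-dimension arithmetic at the top of run_train_pipeline, under assumed values (64 devices, stage_num = 16, op-level model parallel 4, per_batch_size = 1, micro_size = 8; none of these are a shipped configuration):

device_num = 64
stage_num = 16
op_level_model_parallel_num = 4
per_batch_size = 1
micro_size = 8

stage_device_num = device_num // stage_num                            # 4 devices serve each stage
data_parallel_num = stage_device_num // op_level_model_parallel_num   # 1 data-parallel copy per stage
batch_size = per_batch_size * data_parallel_num * micro_size          # 8 samples per global step
print(stage_device_num, data_parallel_num, batch_size)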