diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc index bc9847dfdb9..3309f18d5ec 100755 --- a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc @@ -225,8 +225,9 @@ bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2) { if (rank_tag2 == nullptr) { rank_tag2 = prim2->GetAttr(DEST_RANK); } - MS_EXCEPTION_IF_NULL(rank_tag1); - MS_EXCEPTION_IF_NULL(rank_tag2); + if (!rank_tag1 || !rank_tag2) { + return false; + } auto rank1_value = GetValue(rank_tag1); auto rank2_value = GetValue(rank_tag2); if (rank1_value == rank2_value) { diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index e27444b0763..450bc261218 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -136,7 +136,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags"; // define the parse constant const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1; const char CUSTOM_BPROP_NAME[] = "bprop"; -const char STAGE_NAME[] = "pipeline_stage"; +const char STAGE_NAME[] = "_pipeline_stage"; // define the Namespace name const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 08976ed37fa..fd0b12dcc19 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -154,6 +154,7 @@ class Parameter(Tensor_): self._cast_type = None self._unique = False self.is_in_parallel = _is_in_parallel_mode() + self._pipeline_stage_list = [] if isinstance(default_input, (Tensor_, Tensor)): Tensor_.__init__(self, default_input.dtype, default_input.shape) elif isinstance(default_input, int): @@ -452,6 +453,9 @@ class Parameter(Tensor_): new_param.param_info = self.param_info return new_param + def add_pipeline_stage(self, stage): + self._pipeline_stage_list.append(stage) + def set_data(self, data, slice_shape=False): """ Set Parameter's data. diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 7256aec97bf..1ae544ab911 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -230,6 +230,18 @@ class Cell(Cell_): raise TypeError("'parallel_parameter_name_list' must be list type.") self._parallel_parameter_name_list = value + @property + def pipeline_stage(self): + return self._pipeline_stage + + @pipeline_stage.setter + def pipeline_stage(self, value): + if not isinstance(value, int): + raise TypeError("'pipeline_stage' must be int type.") + self._pipeline_stage = value + for item in self.trainable_params(): + item.add_pipeline_stage(value) + @property def parallel_parameter_merge_net_dict(self): return self._parallel_parameter_merge_net_dict @@ -1297,6 +1309,41 @@ class Cell(Cell_): for cell in self.cells(): cell.recompute(mode, True) + def infer_param_pipeline_stage(self): + """ + Infer pipeline stages of all parameters in the cell. + + Notes: + - If a parameter does not belong to any cell which has been set pipeline_stage, + the parameter should use add_pipeline_stage to add it's pipeline_stage information. + - If a parameter P has been used by two operator in different stages "stageA" and "stageB", + the parameter P should use P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB) + to add it's stage information before use infer_param_pipeline_stage. 
+ + Returns: + The params belong to current stage in pipeline parallel. + + Raises: + RuntimeError: If there is a parameter does not belong to any stage. + """ + from mindspore.communication import get_group_size, get_rank + stage_num = context.get_auto_parallel_context("pipeline_stages") + device_num = get_group_size() + rank_id = get_rank() + per_stage_devices = device_num // stage_num + current_stage = rank_id // per_stage_devices + params = [] + for param in self.trainable_params(): + if not param._pipeline_stage_list: + raise RuntimeError("The parameter {} does not belong to any stage, " + "please check whether the cell where the param locates" + " has been set pipeline_stage. " + "Otherwise, the parameter should use add_pipeline_stage " + "to add its stage information".format(param.name)) + if current_stage in param._pipeline_stage_list: + params.append(param) + return params + class GraphKernel(Cell): """ diff --git a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py index 7ff1679eeba..f19102b1254 100644 --- a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py +++ b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py @@ -302,6 +302,32 @@ class EmbeddingLookup(nn.Cell): return output, self.embedding_table +class EmbeddingLookupPipeline(nn.Cell): + """ + The embedding lookup table for vocabulary + Args: + config(PanguAlphaConfig): the config of network + Inputs: + input_ids: the tokenized inputs with datatype int32 + Returns: + output: Tensor, the embedding vector for the input with shape (batch_size, seq_length, embedding_size) + self.embedding_table: Tensor, the embedding table for the vocabulary + """ + def __init__(self, config): + super(EmbeddingLookupPipeline, self).__init__() + self.vocab_size = config.vocab_size + self.embedding_size = config.embedding_size + if config.word_emb_dp: + self.gather = P.GatherV2().shard(((1, 1), (config.dp, 1))) + else: + self.gather = P.GatherV2().shard(((config.mp, 1), (1, 1))) + self.gather.add_prim_attr("parameter_start", 0) + self.shape = (-1, config.seq_length, config.embedding_size) + + def construct(self, input_ids, table): + output = self.gather(table, input_ids, 0) + return output + class Attention(nn.Cell): """ Self-Attention module for each layer @@ -593,6 +619,44 @@ class Block(nn.Cell): output = self.add(x, mlp_logit) return output, layer_present +class PanguAlpha_EmbeddingPipeLine(nn.Cell): + """ + PanguAlpha_EmbeddingPipeLine + """ + def __init__(self, config): + super(PanguAlpha_EmbeddingPipeLine, self).__init__() + self.word_embedding = EmbeddingLookupPipeline(config) + self.position_embedding = nn.Embedding(config.seq_length, + config.embedding_size, + embedding_table=Normal(0.02)) + self.position_embedding.gather.shard(((1, 1), (config.dp,))) + self.position_embedding.expand.shard(((config.dp, 1),)) + self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1))) + self.dropout = Dropout(1 - config.dropout_rate) + self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),)) + self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),)) + + def construct(self, input_ids, table, input_position): + input_embedding = self.word_embedding(input_ids, table) + position_embedding = self.position_embedding(input_position) + hidden_states = self.add(input_embedding, position_embedding) + hidden_states = self.dropout(hidden_states) + hidden_states = P.Cast()(hidden_states, mstype.float16) + return hidden_states + + +class PanguAlpha_Mask(nn.Cell): + """ + PanguAlpha_Mask + 
""" + def __init__(self, config): + super(PanguAlpha_Mask, self).__init__() + self.get_attention_mask = AttentionMask(config) + self.dtype = config.compute_dtype + self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),)) + def construct(self, input_mask, attention_mask): + attention_mask = self.expand_dims(attention_mask, 1) + return attention_mask class QueryLayerAttention(Attention): r""" @@ -828,6 +892,74 @@ class PanguAlpha_Model(nn.Cell): present_layer = present_layer + (present,) return output_state, present_layer, embedding_table +class PanguAlpha_ModelPipeline(nn.Cell): + """ + The backbone of PanguAlpha network + Args: + config(PanguAlphaConfig): the config of network + Inputs: + input_ids: the tokenized inputs with datatype int32 + input_mask: the mask indicating whether each position is a valid input + layer_past: the previous feature map + Returns: + output_state: Tensor, the output logit of backbone + present_layer: Tensor, the current feature map + embedding_table: Tensor, the embedding table for the vocabulary + """ + def __init__(self, config): + super(PanguAlpha_ModelPipeline, self).__init__() + self.pangu_alpha_embedding = PanguAlpha_EmbeddingPipeLine(config).set_comm_fusion(1) + self.pangu_alpha_embedding.pipeline_stage = 0 + self.pangu_alpha_mask = PanguAlpha_Mask(config) + self.blocks = nn.CellList() + self.top_query_embedding = nn.Embedding(config.seq_length, config.embedding_size, + embedding_table=TruncatedNormal(0.02)) + self.top_query_embedding.gather.shard(((1, 1), (config.dp,))) + self.top_query_embedding.expand.shard(((config.dp, 1),)) + for i in range(config.num_layers): + if i == config.num_layers - 1: + self.top_query_embedding.set_comm_fusion(2) + self.top_query_embedding.pipeline_stage = i * config.stage_num // config.num_layers + per_block = QueryLayer(config).set_comm_fusion(2) + else: + per_block = Block(config, i + 1).set_comm_fusion(2) + per_block.pipeline_stage = i * config.stage_num // config.num_layers + per_block.recompute() + self.blocks.append(per_block) + + if config.self_layernorm: + self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32) + else: + self.layernorm = nn.LayerNorm( + (config.embedding_size,)).to_float(mstype.float32) + self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,))) + self.layernorm.set_comm_fusion(2) + self.layernorm.pipeline_stage = config.stage_num - 1 + self.use_past = config.use_past + self.past = tuple([None] * config.num_layers) + self.dtype = config.compute_dtype + self.num_layers = config.num_layers + + def construct(self, input_ids, input_mask, table, input_position, attention_mask, layer_past=None): + """PanguAlpha model""" + if not self.use_past: + layer_past = self.past + + hidden_states = self.pangu_alpha_embedding(input_ids, table, input_position) + attention_mask = self.pangu_alpha_mask(input_mask, attention_mask) + + present_layer = () + for i in range(self.num_layers-1): + hidden_states, present = self.blocks[i](hidden_states, + attention_mask, layer_past) + present_layer = present_layer + (present,) + top_query_hidden_states = self.top_query_embedding(input_position) + hidden_states, present = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states, + attention_mask, layer_past) + present_layer = present_layer + (present,) + output_state = self.layernorm(hidden_states) + output_state = F.cast(output_state, self.dtype) + return output_state, present_layer class PanguAlpha_Head(nn.Cell): """ @@ -883,6 +1015,35 @@ class PanguAlpha(nn.Cell): 
logits = self.head(output_states, embedding_table) return logits +class PanguAlphaPipeline(nn.Cell): + """ + The PanguAlpha network consisting of two parts the backbone and the head + Args: + config(PanguAlphaConfig): the config of network + Inputs: + input_ids: the tokenized inputs + input_mask: the mask indicating whether each position is a valid input + past: the previous feature map + Returns: + logits: Tensor: the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size) + """ + def __init__(self, config): + super(PanguAlphaPipeline, self).__init__() + self.backbone = PanguAlpha_ModelPipeline(config) + self.head = PanguAlpha_Head(config) + self.head.pipeline_stage = config.stage_num - 1 + self.vocab_size = config.vocab_size + self.embedding_size = config.embedding_size + self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]), + name="embedding_table") + self.embedding_table.add_pipeline_stage(self.backbone.blocks[0].pipeline_stage) + self.embedding_table.add_pipeline_stage(self.head.pipeline_stage) + + def construct(self, input_ids, input_mask, input_position, attention_mask, past=None): + output_states, _ = self.backbone(input_ids, input_mask, self.embedding_table, + input_position, attention_mask, past) + logits = self.head(output_states, self.embedding_table) + return logits class CrossEntropyLoss(nn.Cell): """ @@ -1010,6 +1171,38 @@ class PanguAlphaWithLoss(nn.Cell): output = self.loss(logits, labels, input_mask) return output +class PanguAlphaWithLossPipeline(nn.Cell): + """ + PanguAlpha training loss + Args: + network: backbone network of PanguAlpha + loss: loss function, e.g., crossentropy + eos_token: the end_of_sentence token + Inputs: + input_ids: the tokenized inputs + past: the previous feature map + Returns: + output: Tensor, the loss of the network + """ + def __init__(self, config, network, loss, eos_token=6): + super(PanguAlphaWithLossPipeline, self).__init__(auto_prefix=False) + self.network = network + self.loss = loss + self.eos_token = eos_token + self.slice = P.StridedSlice().shard(((config.dp, 1),)) + self.not_equal = P.NotEqual().shard(((config.dp, 1), ())) + self.batch_size = config.batch_size + self.len = config.seq_length + self.micro_batch_step = config.micro_size + + def construct(self, input_ids, input_position, attention_mask): + tokens = self.slice(input_ids, (0, 0), (self.batch_size // self.micro_batch_step, -1), (1, 1)) + input_mask = F.cast(self.not_equal(tokens, self.eos_token), mstype.float32) + logits = self.network(tokens, input_mask, input_position, attention_mask) + labels = self.slice(input_ids, (0, 1), (self.batch_size // self.micro_batch_step, + self.len + 1), (1, 1)) + output = self.loss(logits, labels, input_mask) + return output class EvalNet(nn.Cell): """ diff --git a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_config.py b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_config.py index ad255cfa9df..d9489b09e04 100644 --- a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_config.py +++ b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_config.py @@ -89,20 +89,27 @@ def set_parse(args_opt): args_opt.embedding_size = 16384 args_opt.num_layers = 64 args_opt.num_heads = 128 + args_opt.per_batch_size = 1 + args_opt.word_emb_dp = 0 if args_opt.run_type == "train": args_opt.start_lr = 6e-5 args_opt.end_lr = 6e-6 - args_opt.optimizer_shard = 0 args_opt.stage_num = 16 args_opt.micro_size = 32 args_opt.op_level_model_parallel_num = 16 + if args_opt.optimizer_shard == 
1: + args_opt.op_level_model_parallel_num = 8 elif args_opt.run_type == "predict": args_opt.stage_num = 4 args_opt.micro_size = 1 + args_opt.op_level_model_parallel_num = 16 + if args_opt.optimizer_shard == 1: + args_opt.op_level_model_parallel_num = 8 elif args_opt.mode == "13B": args_opt.embedding_size = 5120 args_opt.num_layers = 40 args_opt.num_heads = 40 + args_opt.word_emb_dp = 1 args_opt.op_level_model_parallel_num = 8 if args_opt.run_type == "train": args_opt.start_lr = 5e-5 diff --git a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py index 4f34123552b..2fb79399f1c 100644 --- a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py +++ b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py @@ -20,8 +20,11 @@ from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.common.tensor import Tensor import mindspore.common.dtype as mstype -from mindspore.ops.operations.comm_ops import _VirtualDataset from mindspore.nn.wrap.loss_scale import TrainOneStepWithLossScaleCell +from mindspore import context, Parameter +from mindspore.context import ParallelMode +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.communication.management import get_group_size from src.utils import ClipByGlobalNorm GRADIENT_CLIP_TYPE = 1 @@ -65,16 +68,14 @@ reciprocal = P.Reciprocal() def tensor_grad_scale(scale, grad): return grad * reciprocal(scale) - -class VirtualDatasetOneInputCell(nn.Cell): - def __init__(self, backbone): - super(VirtualDatasetOneInputCell, self).__init__(auto_prefix=False) - self._backbone = backbone - self._virtual_dataset = _VirtualDataset() - - def construct(self, *data): - data_ = self._virtual_dataset(*data) - return self._backbone(*data_) +@grad_scale.register("Tensor", "Tensor", "Tensor") +def tensor_grad_scale_pipeline(scale, grad, accu_grad): + accu_grad = F.depend(accu_grad, grad) + new_grad = accu_grad * reciprocal(scale) + accu_grad = F.depend(accu_grad, new_grad) + zeros = F.tensor_mul(accu_grad, 0.0) + _ = F.assign(accu_grad, zeros) + return new_grad class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell): """ @@ -102,7 +103,7 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell): self.optimizer = optimizer self.default_lr = Tensor([0.0], dtype=mstype.float32) self.enable_global_norm = enable_global_norm - self.clip = ClipByGlobalNorm(self.weights) + self.clip = ClipByGlobalNorm(self.weights, config) self.cast = P.Cast() def construct(self, input_ids, input_position=None, attention_mask=None, layer_past=None, sens=None): @@ -142,3 +143,111 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell): else: succ = self.optimizer(grads) return F.depend(loss, succ), cond, scaling_sens + +class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell): + """ + Encapsulation class of PanguAlpha network training. + + Append an optimizer to the training network after that the construct + function can be called to create the backward graph. + + Args: + network (Cell): The training network. Note that loss function should have been added. + optimizer (Optimizer): Optimizer for updating the weights. + scale_update_cell (Cell): Cell to do the loss scale. Default: None. 
+ """ + def __init__(self, network, optimizer, config, scale_update_cell=None, enable_global_norm=True): + super(PanguAlphaTrainPipelineWithLossScaleCell, self).__init__(auto_prefix=False) + self.config = config + self.network = network + self.network.add_flags(defer_inline=True) + self.weights = optimizer.parameters + self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros") + self.optimizer = optimizer + self.enable_global_norm = enable_global_norm + self.grad = C.GradOperation(get_by_list=True, + sens_param=True) + self.reducer_flag = False + self.allreduce = P.AllReduce() + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree) + self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) + self.cast = P.Cast() + self.alloc_status = P.NPUAllocFloatStatus().add_prim_attr("_side_effect_flag", False) + self.get_status = P.NPUGetFloatStatus().add_prim_attr("_side_effect_flag", False) + self.clear_before_grad = P.NPUClearFloatStatus().add_prim_attr("_side_effect_flag", False) + self.reduce_sum = P.ReduceSum(keep_dims=False) + #self.depend_parameter_use = P.ControlDepend(depend_mode=1) + self.base = Tensor(1, mstype.float32) + self.less_equal = P.LessEqual() + self.hyper_map = C.HyperMap() + self.loss_scale = None + self.reshape = P.Reshape() + #self.control = P.ControlDepend(1) + #self.clip_norm = Tensor(1000.0, mstype.float32) + self.loss_scaling_manager = scale_update_cell + if scale_update_cell: + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), + name="loss_scale") + self.clip = ClipByGlobalNorm(self.weights, self.config) + self.micro_size = config.micro_size + + @C.add_flags(has_effect=True) + def construct(self, + input_ids, + input_position, + attention_mask, + past=None, + sens=None): + """Defines the computation performed.""" + weights = self.weights + loss = self.network(input_ids, input_position, attention_mask) + if sens is None: + scaling_sens = self.loss_scale + scaling_sens = self.reshape(scaling_sens, (1,)) + else: + scaling_sens = sens + # alloc status and clear should be right before gradoperation + init = self.alloc_status() + status_clear = self.clear_before_grad(init) + #clear_depend = self.control(status_clear, self.weights) + grads = self.grad(self.network, weights)(input_ids, + input_position, + attention_mask, + self.cast(scaling_sens / self.micro_size, + mstype.float32)) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) + flag_sum = self.reduce_sum(init, (0,)) + loss = F.depend(loss, status_clear) + # apply grad reducer on grads + accu_grads = self.grad_reducer(self.accu_grads) + grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads) + if self.enable_global_norm: + grads, _ = self.clip(grads) + else: + grads = self.hyper_map( + F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), + grads) + if self.is_distributed: + # sum overflow flag over devices + flag_reduce = self.allreduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + overflow = cond + if sens is None: + overflow = 
self.loss_scaling_manager(self.loss_scale, cond) + if overflow: + succ = False + else: + succ = self.optimizer(grads) + ret = (loss, overflow, scaling_sens) + return F.depend(ret, succ) diff --git a/model_zoo/official/nlp/pangu_alpha/src/utils.py b/model_zoo/official/nlp/pangu_alpha/src/utils.py index 6e9a60dcb16..96b44acc83c 100644 --- a/model_zoo/official/nlp/pangu_alpha/src/utils.py +++ b/model_zoo/official/nlp/pangu_alpha/src/utils.py @@ -26,8 +26,8 @@ from mindspore.ops import functional as F import mindspore.common.dtype as mstype from mindspore.common.tensor import Tensor from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR, CosineDecayLR -from mindspore.parallel._utils import _get_global_rank -from mindspore.communication.management import get_group_size +from mindspore.parallel._auto_parallel_context import auto_parallel_context +from mindspore.communication.management import get_rank, get_group_size, create_group from mindspore.nn import AdamWeightDecay from mindspore.common import Parameter, ParameterTuple from mindspore.common.initializer import initializer @@ -71,14 +71,12 @@ class FP32StateAdamWeightDecay(AdamWeightDecay): get_square_sum = C.MultitypeFuncGraph("get_square_sum") - -@get_square_sum.register("Tensor", "Tensor") +@get_square_sum.register("Tensor", "Number") def _get_square_sum(grad, value): - norm = P.ReduceSum(False)(F.square(grad) / value, ()) + norm = P.ReduceSum(False)(F.square(grad), ()) / value norm = F.expand_dims(F.cast(norm, mstype.float32), 0) return norm - apply_global_norm = C.MultitypeFuncGraph("apply_global_norm") @@ -87,6 +85,41 @@ def _apply_global_norm(clip_norm, global_norm, grad): grad = grad * clip_norm / global_norm return grad +def _get_model_parallel_group(mp): + """ + + Calculate the communication group of model parallel dim in one pipeline stage + + """ + rank = get_rank() + stage_nums = auto_parallel_context().get_pipeline_stages() + device_nums = get_group_size() + per_stage_device_nums = device_nums // stage_nums + stage_id = rank // per_stage_device_nums + local_stage_rank_id = rank % per_stage_device_nums + index = local_stage_rank_id // mp + group = range(0, mp) + rank_str_list = [str(x + index * mp + stage_id * per_stage_device_nums) for x in group] + rank_list_str = "-".join(rank_str_list) + rank_list = [x + index * mp + stage_id * per_stage_device_nums for x in group] + return rank_list, rank_list_str + +def _get_pipeline_group(): + """ + + Calculate the communication group between all pipeline stages + + """ + rank = get_rank() + stage_nums = auto_parallel_context().get_pipeline_stages() + device_nums = get_group_size() + per_stage_device_nums = device_nums // stage_nums + local_stage_rank_id = rank % per_stage_device_nums + group = range(0, stage_nums) + rank_list = [local_stage_rank_id + x * per_stage_device_nums for x in group] + rank_str_list = [str(local_stage_rank_id + x * per_stage_device_nums) for x in group] + rank_list_str = "-".join(rank_str_list) + return rank_list, rank_list_str class GlobalNorm(nn.Cell): """ @@ -107,9 +140,9 @@ class GlobalNorm(nn.Cell): self.group_size = get_group_size() for item in self.allreduce_filter: if item: - self.values.append(Tensor([1.0], mstype.float32)) + self.values.append(1.0) else: - self.values.append(Tensor([self.group_size * 1.0], mstype.float32)) + self.values.append(self.group_size * 1.0) self.values = tuple(self.values) def construct(self, grads): @@ -119,6 +152,37 @@ class GlobalNorm(nn.Cell): global_norms = 
F.sqrt(P.AllReduce()(F.addn(square_sum_dp))) return global_norms +class GlobalNormPipline(nn.Cell): + """ + Calculate the global norm value of given tensors + """ + def __init__(self, params, config): + super(GlobalNormPipline, self).__init__() + self.norm = nn.Norm() + self.hyper_map = C.HyperMap() + self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name + and "position_embedding.embedding_table" not in x.name for x in params) + self.allreduce_group_size = () + for item in self.allreduce_filter: + if item: + self.allreduce_group_size = self.allreduce_group_size + (1.0,) + else: + self.allreduce_group_size = self.allreduce_group_size + (config.mp * 1.0,) + self.length = len(params) + group_list, group_name = _get_model_parallel_group(config.mp) + create_group(group_name, group_list) + self.allreduce = P.AllReduce(group=group_name) + pipeline_group_list, pipeline_group_name = _get_pipeline_group() + create_group(pipeline_group_name, pipeline_group_list) + self.allreduce2 = P.AllReduce(group=pipeline_group_name) + + def construct(self, grads): + square_sum = self.hyper_map(get_square_sum, grads, self.allreduce_group_size) + square_reduce_sum = F.addn(square_sum) + stage_square_reduce_sum = self.allreduce(square_reduce_sum) + global_square_reduce_sum = self.allreduce2(stage_square_reduce_sum) + global_norms = F.sqrt(global_square_reduce_sum) + return global_norms class ClipByGlobalNorm(nn.Cell): """ @@ -126,10 +190,12 @@ class ClipByGlobalNorm(nn.Cell): Clip grads by global norm """ - - def __init__(self, params, clip_norm=1.0): + def __init__(self, params, config, clip_norm=1.0): super(ClipByGlobalNorm, self).__init__() - self.global_norm = GlobalNorm(params) + if config.stage_num > 1: + self.global_norm = GlobalNormPipline(params, config) + else: + self.global_norm = GlobalNorm(params) self.clip_norm = Tensor([clip_norm], mstype.float32) self.hyper_map = C.HyperMap() @@ -140,14 +206,6 @@ class ClipByGlobalNorm(nn.Cell): grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads) return grads, global_norm_value - -def _get_model_parallel_group(dp, mp): - rank = _get_global_rank() - group = range(0, mp) - index = rank // dp - return [x + index * mp for x in group] - - class LearningRate(LearningRateSchedule): """ Warmup-decay learning rate for PanguAlpha network. @@ -258,6 +316,10 @@ def add_training_params(opt): type=int, default=2000, help="Warmup step, default is 2000.") + opt.add_argument("--decay_steps", + type=int, + default=200000, + help="Decay step, default is 200000.") opt.add_argument("--optimizer", type=str, default="adam", @@ -298,6 +360,11 @@ def add_training_params(opt): type=int, default=8, help="The model parallel way. default 8") + opt.add_argument("--word_emb_dp", + type=int, + default=1, + choices=[0, 1], + help="Whether do data parallel in word embedding. 
default 1") def get_args(inference=False): diff --git a/model_zoo/official/nlp/pangu_alpha/train.py b/model_zoo/official/nlp/pangu_alpha/train.py index ff4257b69dd..9553673cddf 100644 --- a/model_zoo/official/nlp/pangu_alpha/train.py +++ b/model_zoo/official/nlp/pangu_alpha/train.py @@ -29,10 +29,11 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell import mindspore.common.dtype as mstype from mindspore.parallel import set_algo_parameters from mindspore.parallel._cost_model_context import _set_multi_subgraphs - +from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell from src.dataset import create_dataset -from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss -from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell +from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss,\ + PanguAlphaPipeline, PanguAlphaWithLossPipeline, CrossEntropyLoss +from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell from src.pangu_alpha_config import PANGUALPHAConfig, set_parse from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay from src.utils import download_data @@ -132,14 +133,14 @@ def run_train(args_opt): micro_size=args_opt.micro_size, eod_reset=bool(args_opt.eod_reset), param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16, - word_emb_dp=True) + word_emb_dp=bool(args_opt.word_emb_dp)) print("===config is: ", config, flush=True) # Define network pangu_alpha = PanguAlpha(config) loss = CrossEntropyLoss(config) pangu_alpha_with_loss = PanguAlphaWithLoss(config, pangu_alpha, loss) - pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss) + pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss) print("=====args_opt is: ", args_opt, flush=True) @@ -189,8 +190,111 @@ def run_train(args_opt): print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True) model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size, dataset_sink_mode=True) +def run_train_pipeline(args_opt): + r""" + The main training process in pipeline. 
+ """ + device_id = int(os.getenv("DEVICE_ID")) + context.set_context(save_graphs=False, + mode=context.GRAPH_MODE, + device_target="Ascend", + device_id=device_id) + context.set_context(variable_memory_max_size="31GB") + if args_opt.distribute == "true": + D.init() + device_num = D.get_group_size() + rank_id = D.get_rank() + context.reset_auto_parallel_context() + context.set_auto_parallel_context( + parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, + gradients_mean=False, + device_num=device_num, + full_batch=True, + loss_repeated_mean=True, + enable_parallel_optimizer=bool(args_opt.optimizer_shard), + pipeline_stages=args_opt.stage_num) + set_algo_parameters(elementwise_op_strategy_follow=True) + _set_multi_subgraphs() + else: + rank_id = int(os.getenv("RANK_ID")) + device_num = 1 + model_parallel_num = args_opt.op_level_model_parallel_num + stage_device_num = int(device_num / args_opt.stage_num) + data_parallel_num = int(stage_device_num / model_parallel_num) + per_batch_size = args_opt.per_batch_size + batch_size = per_batch_size * data_parallel_num * args_opt.micro_size + config = PANGUALPHAConfig( + data_parallel_num=data_parallel_num, + model_parallel_num=model_parallel_num, + batch_size=batch_size, + seq_length=args_opt.seq_length, + vocab_size=args_opt.vocab_size, + embedding_size=args_opt.embedding_size, + num_layers=args_opt.num_layers, + num_heads=args_opt.num_heads, + expand_ratio=4, + post_layernorm_residual=False, + dropout_rate=0.1, + compute_dtype=mstype.float16, + use_past=False, + self_layernorm=True, + stage_num=args_opt.stage_num, + micro_size=args_opt.micro_size, + word_emb_dp=bool(args_opt.word_emb_dp)) + print("===config is: ", config, flush=True) + pangu_alpha = PanguAlphaPipeline(config) + loss = CrossEntropyLoss(config) + pangu_alpha_with_loss = PipelineCell(PanguAlphaWithLossPipeline(config, pangu_alpha, loss), config.micro_size) + pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss) + print("=====args_opt is: ", args_opt, flush=True) + lr = LearningRate(learning_rate=args_opt.start_lr, + end_learning_rate=args_opt.end_lr, + warmup_steps=args_opt.warmup_step, + decay_steps=args_opt.decay_steps) + params = pangu_alpha.infer_param_pipeline_stage() + decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower() + decay_params = list(filter(decay_filter, params)) + other_params = list(filter(lambda x: not decay_filter(x), params)) + group_params = [{ + 'params': decay_params, + 'weight_decay': 1e-1 + }, { + 'params': other_params, + 'weight_decay': 0.0 + }, { + 'order_params': params + }] + if args_opt.optimizer == "lamb": + optimizer = nn.Lamb(group_params, learning_rate=lr) + else: + optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8) + ds = create_dataset(config.batch_size, data_path=args_opt.data_url, eod_reset=True, + data_start_index=0, full_batch=True) + epoch_num = args_opt.epoch_size + step_per_epoch = ds.get_dataset_size() + callback_size = args_opt.sink_size + actual_epoch_num = int(epoch_num * step_per_epoch / callback_size) + callback = [ + TimeMonitor(callback_size), + LossCallBack(callback_size, rank_id, config.stage_num) + ] + loss_scale_value = math.pow(2, 32) + update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, + scale_factor=2, + scale_window=1000) + pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell( + pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell) + model = 
Model(pangu_alpha_with_grads) + model.train(actual_epoch_num, + ds, + callbacks=callback, + sink_size=callback_size, + dataset_sink_mode=True) if __name__ == "__main__": opt = get_args() set_parse(opt) - run_train(opt) + if opt.stage_num > 1: + run_train_pipeline(opt) + else: + run_train(opt)
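
The patch threads pipeline-stage bookkeeping through Parameter.add_pipeline_stage, the Cell.pipeline_stage setter and Cell.infer_param_pipeline_stage. Below is a minimal standalone sketch of that selection logic; FakeParam is a hypothetical stand-in for mindspore.Parameter, and the rank/device counts are illustrative, not part of the patch.

from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeParam:
    """Stand-in for mindspore.Parameter: keeps the list filled by add_pipeline_stage()."""
    name: str
    pipeline_stage_list: List[int] = field(default_factory=list)

    def add_pipeline_stage(self, stage: int):
        self.pipeline_stage_list.append(stage)


def infer_param_pipeline_stage(params, rank_id, device_num, stage_num):
    """Mirror of Cell.infer_param_pipeline_stage: keep only the parameters whose
    recorded stages contain the stage that owns this rank."""
    per_stage_devices = device_num // stage_num
    current_stage = rank_id // per_stage_devices
    selected = []
    for param in params:
        if not param.pipeline_stage_list:
            raise RuntimeError("The parameter {} does not belong to any stage".format(param.name))
        if current_stage in param.pipeline_stage_list:
            selected.append(param)
    return selected


embedding = FakeParam("embedding_table")
embedding.add_pipeline_stage(0)   # used by the first block ...
embedding.add_pipeline_stage(3)   # ... and shared with the head on the last stage
block_w = FakeParam("blocks.5.weight")
block_w.add_pipeline_stage(1)

# 32 devices split into 4 stages -> 8 devices per stage; rank 25 sits in stage 3.
owned = infer_param_pipeline_stage([embedding, block_w], rank_id=25, device_num=32, stage_num=4)
print([p.name for p in owned])    # ['embedding_table']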
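
PanguAlpha_ModelPipeline spreads its transformer blocks over stages with i * config.stage_num // config.num_layers. A small sketch with illustrative layer and stage counts (not the 2.6B/13B/200B presets) showing that this assigns num_layers // stage_num consecutive blocks to each stage:

# Layer-to-stage mapping used in PanguAlpha_ModelPipeline:
# per_block.pipeline_stage = i * config.stage_num // config.num_layers
num_layers, stage_num = 8, 4

assignment = {}
for i in range(num_layers):
    stage = i * stage_num // num_layers
    assignment.setdefault(stage, []).append(i)

print(assignment)
# {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7]}  -> num_layers // stage_num blocks per stage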
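
PanguAlphaWithLossPipeline slices each incoming micro-batch into tokens and shifted next-token labels with StridedSlice. A numpy rendering of those two slices, assuming StridedSlice treats a negative end index like Python slicing; the shapes and values are illustrative only.

import numpy as np

# The incoming batch holds seq_length + 1 token ids, and each micro-batch step
# consumes batch_size // micro_size rows.
batch_size, micro_size, seq_length = 8, 4, 6
input_ids = np.arange(batch_size * (seq_length + 1)).reshape(batch_size, seq_length + 1)

rows = batch_size // micro_size
tokens = input_ids[0:rows, 0:-1]                  # StridedSlice((0, 0), (rows, -1), (1, 1))
labels = input_ids[0:rows, 1:seq_length + 1]      # StridedSlice((0, 1), (rows, seq_length + 1), (1, 1))

assert tokens.shape == labels.shape == (rows, seq_length)
assert (labels == input_ids[0:rows, 1:]).all()    # labels are the inputs shifted left by one token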
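
The new grad_scale overload (tensor_grad_scale_pipeline) un-scales the accumulated gradient by the loss scale and then clears the accumulation buffer; the F.depend calls only enforce that ordering. A rough numpy analogue, with a hypothetical helper name and toy values:

import numpy as np

def tensor_grad_scale_pipeline_sketch(scale, accu_grad):
    """Numpy rendering of the overload: read the accumulated gradient, divide out the
    loss scale, then zero the buffer in place (the F.assign step) for the next step."""
    new_grad = accu_grad * (1.0 / scale)   # reciprocal(scale) * accu_grad
    accu_grad[...] = 0.0                   # F.assign(accu_grad, zeros)
    return new_grad

accu = np.array([1024.0, 2048.0], dtype=np.float32)   # gradients summed over micro-batches
grad = tensor_grad_scale_pipeline_sketch(scale=1024.0, accu_grad=accu)
print(grad)   # [1. 2.]
print(accu)   # [0. 0.]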
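
GlobalNormPipline builds two communication groups, one across the model-parallel ranks inside a stage and one across stages, so two AllReduce calls yield the global squared sum. The rank arithmetic from _get_model_parallel_group and _get_pipeline_group, reproduced as plain Python with an illustrative 32-device, 4-stage layout:

def model_parallel_group(rank, device_num, stage_num, mp):
    """Ranks sharing one model-parallel group inside the stage that owns `rank`
    (same arithmetic as _get_model_parallel_group in src/utils.py)."""
    per_stage = device_num // stage_num
    stage_id = rank // per_stage
    local_rank = rank % per_stage
    index = local_rank // mp
    return [x + index * mp + stage_id * per_stage for x in range(mp)]

def pipeline_group(rank, device_num, stage_num):
    """The same local rank taken from every stage (as _get_pipeline_group)."""
    per_stage = device_num // stage_num
    local_rank = rank % per_stage
    return [local_rank + x * per_stage for x in range(stage_num)]

# 32 devices, 4 stages (8 devices each), model parallel 4: rank 13 lives in stage 1.
print(model_parallel_group(13, 32, 4, 4))   # [12, 13, 14, 15]
print(pipeline_group(13, 32, 4))            # [5, 13, 21, 29]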
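
run_train_pipeline derives its parallel layout from the stage count before building PANGUALPHAConfig. The same bookkeeping, spelled out with illustrative launch values rather than any released configuration:

# Device and batch arithmetic at the top of run_train_pipeline.
device_num = 128                       # D.get_group_size()
stage_num = 16                         # args_opt.stage_num
micro_size = 32                        # args_opt.micro_size
op_level_model_parallel_num = 8        # args_opt.op_level_model_parallel_num
per_batch_size = 1                     # args_opt.per_batch_size

stage_device_num = device_num // stage_num                            # 8 devices per stage
data_parallel_num = stage_device_num // op_level_model_parallel_num   # 1 data-parallel copy per stage
batch_size = per_batch_size * data_parallel_num * micro_size          # 32, the batch fed to PipelineCell

print(stage_device_num, data_parallel_num, batch_size)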