diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 008d4413d4c..9b147ada041 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -553,7 +553,7 @@ class _CellGraphExecutor: """compile graph in auto parallel mode.""" if not auto_parallel_mode: replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode) - self._updata_param_node_default_input(phase, replace) + self._update_param_node_default_input(phase, replace) return obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase) @@ -564,13 +564,13 @@ class _CellGraphExecutor: if not context.get_context("enable_debug_runtime") or context.get_context("enable_ge"): obj.load_parameter_slice(None) - self._updata_param_node_default_input(phase, replace) + self._update_param_node_default_input(phase, replace) # set parallel inputs in sink mode if is_sink_mode: obj.set_parallel_input_with_inputs(*args) - def _updata_param_node_default_input(self, phase, replace): + def _update_param_node_default_input(self, phase, replace): new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])} return self._graph_executor.updata_param_node_default_input(phase, new_param) diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py index 297da787c50..4702161cc0c 100644 --- a/mindspore/nn/layer/activation.py +++ b/mindspore/nn/layer/activation.py @@ -332,15 +332,14 @@ class LeakyReLU(Cell): validator.check_value_type('alpha', alpha, [float, int], self.cls_name) self.greater_equal = P.GreaterEqual() self.mul = P.Mul() - self.maximum = P.Maximum() self.alpha = alpha + self.select_op = P.Maximum() + if self.alpha > 1: + self.select_op = P.Minimum() def construct(self, x): alpha_array = P.Cast()(F.scalar_to_array(self.alpha), P.DType()(x)) - if self.alpha <= 1: - out = self.maximum(alpha_array * x, x) - else: - out = self.maximum(alpha_array * x, x) + out = self.select_op(alpha_array * x, x) return out diff --git a/mindspore/parallel/_dp_allreduce_fusion.py b/mindspore/parallel/_dp_allreduce_fusion.py index 78108595664..8fa4d16c0c4 100644 --- a/mindspore/parallel/_dp_allreduce_fusion.py +++ b/mindspore/parallel/_dp_allreduce_fusion.py @@ -149,9 +149,9 @@ def _set_fusion_strategy_by_size(data_size_list, group="hccl_world_group"): if not isinstance(data_size, (int, float)): raise TypeError('data_size in data_size_list is invalid') - c_array_sizeList = _c_array(ctypes.c_float, data_size_list) + c_array_size_list = _c_array(ctypes.c_float, data_size_list) c_size_num = ctypes.c_uint(len(data_size_list)) c_group = _c_str(group) - ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_sizeList) + ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_size_list) if ret != 0: raise RuntimeError('Allreduce split error') diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py index 760b9d53eec..eabf169ccc9 100644 --- a/mindspore/parallel/_utils.py +++ b/mindspore/parallel/_utils.py @@ -30,9 +30,11 @@ def _get_parallel_mode(): """Get parallel mode.""" return auto_parallel_context().get_parallel_mode() + def _is_in_auto_parallel_mode(): return _get_parallel_mode() in [ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL] + def _get_full_batch(): """Get whether to use full_batch.""" return auto_parallel_context().get_full_batch() @@ -51,12 +53,8 @@ def _check_task_sink_envs(): """ import os task_sink = os.getenv("GRAPH_OP_RUN") - if task_sink: - try: - if int(task_sink) == 1: - return False - except ValueError: - return True + if task_sink and task_sink.isdigit() and int(task_sink) == 1: + return False return True diff --git a/mindspore/parallel/nn/__init__.py b/mindspore/parallel/nn/__init__.py index a87d6e3ca54..12ce3107429 100644 --- a/mindspore/parallel/nn/__init__.py +++ b/mindspore/parallel/nn/__init__.py @@ -18,9 +18,9 @@ NOTE: This is an experimental interface that is subject to change and/or deletion. """ from .transformer import * -from .layers import * -from .loss import * -from .op_parallel_config import * +from .layers import FixedSparseAttention +from .loss import CrossEntropyLoss +from .op_parallel_config import OpParallelConfig __all__ = [] __all__.extend(transformer.__all__) diff --git a/mindspore/parallel/nn/layers.py b/mindspore/parallel/nn/layers.py index 6a7d5bd7792..dc4a4b4787d 100644 --- a/mindspore/parallel/nn/layers.py +++ b/mindspore/parallel/nn/layers.py @@ -73,6 +73,7 @@ def _valid_type_checks(types, class_name): # as the input of _args_type_validator_check is fixed, so we need to manually change the input order partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name) return partial_check(name, type(value)) + return validator_check_func @@ -83,6 +84,7 @@ def _valid_value_checks(types, class_name): # as the input of _args_type_validator_check is fixed, so we need to manually change the input order partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name) return partial_check(name, value) + return validator_check_func @@ -334,7 +336,7 @@ class _Linear(Cell): if self.activation_flag: # some operations has many primitives, need to manually set the shard if self.act_name.lower() == "leakyrelu": - self.activation.maximum.shard((strategy_activation[0], strategy_activation[0])) + self.activation.select_op.shard((strategy_activation[0], strategy_activation[0])) elif self.act_name.lower() == "logsigmoid": self.activation.mul.shard((strategy_activation[0], ())) self.activation.exp.shard(strategy_activation) @@ -402,6 +404,7 @@ class FixedSparseAttention(nn.Cell): >>> print(output.shape) (2, 1024, 512) """ + @_args_type_validator_check(batch_size=Validator.check_positive_int, num_heads=Validator.check_positive_int, size_per_head=Validator.check_positive_int, @@ -437,7 +440,7 @@ class FixedSparseAttention(nn.Cell): self.transpose = P.Transpose().shard(((dp, 1, mp, 1),)) self.batch_matmul = P.BatchMatMul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1))) self.multiply = P.Mul().shard(((dp, 1, 1, 1), (1, 1, 1))) - self.multiply_data = Tensor([-10000.0,], dtype=mstype.float32) + self.multiply_data = Tensor([-10000.0], dtype=mstype.float32) self.parallel_config = parallel_config size_per_head_list = [64, 128] if self.seq_length != 1024: @@ -460,7 +463,7 @@ class FixedSparseAttention(nn.Cell): global_mask_original = -10000 * global_mask_original global_mask_fx = global_mask_original.reshape((self.seq_length // 16, 16, self.global_size // 16, 16)) global_mask = np.transpose(global_mask_fx, (2, 0, 1, 3)) - global_mask = np.repeat(global_mask[np.newaxis, :, :, :, :,], self.batch_size, axis=0) + global_mask = np.repeat(global_mask[np.newaxis, :, :, :, :], self.batch_size, axis=0) global_mask = global_mask.reshape((self.batch_size * self.global_size // 16, self.seq_length // 16, 16, 16)) self.global_mask = Tensor(global_mask, mstype.float32) self.local_mask_triangle = Tensor(np.tril(local_ones), mstype.float32) @@ -578,6 +581,7 @@ class _CumSum(Cell): Outputs: Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`. """ + def __init__(self, config): super(_CumSum, self).__init__() dp = config.data_parallel @@ -598,14 +602,13 @@ class _CumSum(Cell): self.delta = Tensor(1, mstype.int32) self.add = P.TensorAdd().shard(((1,), ())) - def construct(self, expert_mask): - # origin_shape: (self.expert_parallel, tokens_per_device, self.expert_dim) + # origin_shape: (expert_parallel, tokens_per_device, self.expert_dim) origin_shape = self.shape(expert_mask) tokens_per_device = origin_shape[1] - # expert_mask_trans's shape: (self.expert_parallel, self.expert_dim, tokens_per_device) + # expert_mask_trans's shape: (expert_parallel, self.expert_dim, tokens_per_device) expert_mask_trans = self.transpose(expert_mask, (0, 2, 1)) - # expert_mask_reshaped's shape: (self.expert_parallel*self.expert_dim, tokens_per_device) + # expert_mask_reshaped's shape: (expert_parallel*self.expert_dim, tokens_per_device) expert_mask_reshaped = self.reshape(expert_mask_trans, (-1, tokens_per_device)) one_dim = self.expand(self.range(self.start, self.add(self.limit, tokens_per_device), self.delta), 0) @@ -614,11 +617,11 @@ class _CumSum(Cell): up_tri_matrix = self.greater(one_dim, other_dim) up_tri_matrix = self.cast(up_tri_matrix, mstype.float32) - # cum_sum's shape: (self.expert_parallel*self.expert_dim, tokens_per_device) + # cum_sum's shape: (expert_parallel*self.expert_dim, tokens_per_device) cum_sum = self.matmul(expert_mask_reshaped, up_tri_matrix) - # cum_sum's shape: (self.expert_parallel, self.expert_dim, tokens_per_device) + # cum_sum's shape: (expert_parallel, self.expert_dim, tokens_per_device) cum_sum = self.reshape(cum_sum, (origin_shape[0], origin_shape[2], tokens_per_device)) - # cum_sum's shape: (self.expert_parallel, tokens_per_device, self.expert_dim) + # cum_sum's shape: (expert_parallel, tokens_per_device, self.expert_dim) cum_sum = self.transpose3(cum_sum, (0, 2, 1)) return cum_sum @@ -646,6 +649,7 @@ class Router(Cell): Outputs: Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`. """ + def __init__(self, d_model, moe_config, @@ -704,6 +708,7 @@ class SwitchRouter(Cell): Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`, Tensor of shape :math:`(1)`. """ + def __init__(self, d_model, moe_config, @@ -752,9 +757,9 @@ class SwitchRouter(Cell): """ Computing the load balance loss. """ - # density_1's shape: (self.expert_parallel, self.expert_dim) + # density_1's shape: (expert_parallel, self.expert_dim) density_1 = self.reduce_mean(expert_mask, 1) - # density_1_proxy's shape: (self.expert_parallel, self.expert_dim) + # density_1_proxy's shape: (expert_parallel, self.expert_dim) density_1_proxy = self.reduce_mean2(router_prob, 1) loss = self.mul(density_1, density_1_proxy) loss = self.reduce_mean3(loss) @@ -766,20 +771,19 @@ class SwitchRouter(Cell): Keeping only the tokens that fit within expert_capacity. """ cumsum = self.cumsum(expert_mask) - # position_in_expert's shape: (self.expert_parallel, tokens_per_device, self.expert_dim) + # position_in_expert's shape: (expert_parallel, tokens_per_device, self.expert_dim) position_in_expert = self.mul4(cumsum, expert_mask) less_result = self.less(position_in_expert, expert_capacity) - # expert_mask's shape: (self.expert_parallel, tokens_per_device, self.expert_dim) + # expert_mask's shape: (expert_parallel, tokens_per_device, self.expert_dim) expert_mask = self.mul5(less_result, expert_mask) - # expert_mask_flat's shape: (self.expert_parallel, tokens_per_device) + # expert_mask_flat's shape: (expert_parallel, tokens_per_device) expert_mask_flat = self.reduce_sum(expert_mask, -1) # Mask out the experts that have overflowed the expert_capacity. - # expert_gate's shape: (self.expert_parallel, tokens_per_device) + # expert_gate's shape: (expert_parallel, tokens_per_device) expert_gate = self.mul6(expert_gate, expert_mask_flat) return expert_gate, expert_mask_flat, position_in_expert - def construct(self, router_logits): router_logits_shape = self.shape(router_logits) router_logits = self.reshape(router_logits, (-1, router_logits_shape[-1])) @@ -791,9 +795,9 @@ class SwitchRouter(Cell): # Probabilities for each token of what expert is should be sent to router_prob = self.softmax(router_logits) - # shape: (self.expert_parallel, tokens_per_device) + # shape is : (expert_parallel, tokens_per_device) expert_index, expert_gate = self.argmax(router_prob) - # expert_mask's shape: (self.expert_parallel, tokens_per_device, self.expert_dim) + # expert_mask's shape: (expert_parallel, tokens_per_device, self.expert_dim) expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value) # Computing the load balance loss: @@ -802,12 +806,12 @@ class SwitchRouter(Cell): expert_gate, expert_mask_flat, position_in_expert = \ self._maskout_overflowed_tokens(expert_mask, expert_capacity, expert_gate) - # combine_tensor's shape: (self.expert_parallel, tokens_per_device) + # combine_tensor's shape: (expert_parallel, tokens_per_device) combine_tensor = self.mul7(expert_gate, expert_mask_flat) - # combine_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim) + # combine_tensor's shape: (expert_parallel, tokens_per_device, self.expert_dim) combine_tensor = self.mul8(self.expand(combine_tensor, -1), self.onehot2(expert_index, self.expert_dim, self.on_value, self.off_value)) - # combine_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim, self.expert_capacity) + # combine_tensor's shape: (expert_parallel, tokens_per_device, self.expert_dim, self.expert_capacity) combine_tensor = self.mul9(self.expand2(combine_tensor, -1), self.onehot3(self.cast(position_in_expert, mstype.int32), expert_capacity, self.on_value, self.off_value)) diff --git a/mindspore/parallel/nn/loss.py b/mindspore/parallel/nn/loss.py index b652a190db4..1f5ad185163 100644 --- a/mindspore/parallel/nn/loss.py +++ b/mindspore/parallel/nn/loss.py @@ -93,7 +93,7 @@ class CrossEntropyLoss(Cell): """ self._check_input(logits, label, input_mask) - # [bs*seq_length, vocab_size] + # the shape is [bs*seq_length, vocab_size] logits = F.cast(logits, mstype.float32) # LogSoftmax for logits over last dimension _, logit_max = self.max(logits) diff --git a/mindspore/parallel/nn/op_parallel_config.py b/mindspore/parallel/nn/op_parallel_config.py index 1a96d2a0fb3..37a2a2a1081 100644 --- a/mindspore/parallel/nn/op_parallel_config.py +++ b/mindspore/parallel/nn/op_parallel_config.py @@ -143,7 +143,6 @@ def _check_config(config): device_num = D.get_group_size() optimizer_shard = context.get_auto_parallel_context("enable_parallel_optimizer") - # dp * pp * pipeline_stage <= device_num if config.data_parallel * config.model_parallel * pipeline_stage > device_num: raise ValueError(f"The product of the data parallel {config.data_parallel}, " f"model parallel {config.model_parallel} " @@ -155,10 +154,3 @@ def _check_config(config): logger.warning(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the" f" optimizer_shard {config.optimizer_shard} in the OpParallelConfig. Please check the " f"optimizer_shard to make them consistent.") - - # pipeline_stage <= micro_batch_num - if hasattr(config, 'pipeline_stage') and hasattr(config, 'micro_batch_num')\ - and config.pipeline_stage < config.micro_batch_num: - raise ValueError( - f"The pipeline stage {config.pipeline_stage} should be greater than the micro_batch_num " - f"{config.micro_batch_num}.") diff --git a/mindspore/parallel/nn/transformer.py b/mindspore/parallel/nn/transformer.py index 5196873ae57..1c80f072f6c 100644 --- a/mindspore/parallel/nn/transformer.py +++ b/mindspore/parallel/nn/transformer.py @@ -396,13 +396,14 @@ class FeedForward(Cell): _check_input_shape(F.shape(x), "x", self.cls_name, 3) _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name) x = self.cast(x, mstype.float16) - # [bs, seq_length, ffn_hidden_size] + # returned shape is [bs, seq_length, ffn_hidden_size] hidden = self.mapping(x) output = self.projection(hidden) - # [bs, seq_length, hidden_size] + # returned shape is [bs, seq_length, hidden_size] output = self.dropout(output) return output + @constexpr def calculate_expert_capacity(k, tokens_per_device, capacity_factor, expert_dim): return math.ceil(k * tokens_per_device * capacity_factor / expert_dim) @@ -588,7 +589,7 @@ class AttentionMask(Cell): mask_right = self.reshape(input_mask, shape_right) attention_mask = self.mul(mask_left, mask_right) lower_traiangle = self.expand_dim(self.lower_triangle_mask, 0) - # [bs, seq_length, seq_length] + # the returned shape is [bs, seq_length, seq_length] attention_mask = self.multiply( attention_mask, lower_traiangle) return attention_mask @@ -889,18 +890,18 @@ class MultiHeadAttention(Cell): query = self.dense1(query_tensor) key = self.dense2(key_tensor) value = self.dense3(value_tensor) - # [bs, num_heads, seq_length, size_per_head] + # the returned shape is [bs, num_heads, seq_length, size_per_head] query = self.transpose( F.reshape( query, (-1, query_tensor_original_shape[1], self.n_head, self.size_per_head)), (0, 2, 1, 3)) - # [bs, num_heads, size_per_head, seq_length] + # the returned shape is [bs, num_heads, size_per_head, seq_length] key = self.transpose( F.reshape( key, (-1, key_tensor_original_shape[1], self.n_head, self.size_per_head)), (0, 2, 3, 1)) - # [bs, num_heads, seq_length, size_per_head] + # the returned shape is [bs, num_heads, seq_length, size_per_head] value = self.transpose( F.reshape( value, @@ -949,7 +950,7 @@ class MultiHeadAttention(Cell): layer_present = (key_present, value_present) # multi head attention considering attention mask - # [bs, seq_length, hidden_size] + # the return shape is [bs, seq_length, hidden_size] attention = self._attn(query, key, value, attention_mask) # Output output = self.projection(attention) @@ -1019,8 +1020,8 @@ class MultiHeadAttention(Cell): ori_dtype = P.DType()(score) score = P.Cast()(score, self.softmax_dtype) - # for input size of (bs, 1) namely the second graph, the shape of attention_mask matrix should be - # (bs, 1, 1, seq_length) + # for input size of (bs, 1) namely the second graph, + # the shape of attention_mask matrix should be (bs, 1, 1, seq_length) if self.use_past and not self.is_first_iteration: # Calculate the current total token current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0), @@ -1508,7 +1509,7 @@ class TransformerDecoderLayer(Cell): memory_mask=None, init_reset=True, batch_valid_length=None): self._check_input(hidden_stats, decoder_mask, encoder_output, memory_mask, init_reset, batch_valid_length) - # [bs, seq_length, embedding_size] + # the returned shape is [bs, seq_length, embedding_size] input_x = self.layernorm1(hidden_stats) input_x = F.cast(input_x, self.dtype) diff --git a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py index 2f68583fcef..7f235f4dcf5 100644 --- a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py +++ b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py @@ -33,20 +33,19 @@ class EmbeddingLayer(nn.Cell): def __init__(self, config): super(EmbeddingLayer, self).__init__() # Only for the pipeline mode, the embedding needs to be row sliced. - copied_parallel_config = copy.deepcopy(config.parallel_config) - if copied_parallel_config.pipeline_stage > 1: - copied_parallel_config.vocab_emb_dp = False self.word_embedding = VocabEmbedding(vocab_size=config.vocab_size, embedding_size=config.hidden_size, param_init=initializer("normal", [config.vocab_size, config.hidden_size], dtype=config.param_init_type), - parallel_config=copied_parallel_config.embedding_dp_mp_config) + parallel_config=config.parallel_config.embedding_dp_mp_config) + copied_parallel_config = copy.deepcopy(config.parallel_config) + copied_parallel_config.vocab_emb_dp = True self.position_embedding = VocabEmbedding(vocab_size=config.seq_length, embedding_size=config.hidden_size, param_init=initializer("normal", [config.seq_length, config.hidden_size], dtype=config.param_init_type), - parallel_config=config.parallel_config.embedding_dp_mp_config) + parallel_config=copied_parallel_config.embedding_dp_mp_config) self.add = P.Add().shard( ((config.parallel_config.data_parallel, 1, 1), (config.parallel_config.data_parallel, 1, 1))) self.dropout = nn.Dropout(1 - config.dropout_rate) @@ -249,13 +248,14 @@ class PanguAlpha_Model(Cell): param_init_type=config.param_init_type, use_past=config.use_past, parallel_config=config.parallel_config).blocks - + copied_parallel_config = copy.deepcopy(config.parallel_config) + copied_parallel_config.vocab_emb_dp = True self.top_query_embedding = VocabEmbedding(vocab_size=config.seq_length, embedding_size=config.hidden_size, param_init=initializer("normal", [config.seq_length, config.hidden_size], dtype=config.param_init_type), - parallel_config=config.parallel_config.embedding_dp_mp_config) + parallel_config=copied_parallel_config.embedding_dp_mp_config) self.top_query_embedding.pipeline_stage = config.parallel_config.pipeline_stage - 1 if config.parallel_config.pipeline_stage > 1: self.top_query_embedding.set_comm_fusion(2) diff --git a/model_zoo/official/nlp/pangu_alpha/train.py b/model_zoo/official/nlp/pangu_alpha/train.py index 9e7f177032a..3177f2312d0 100644 --- a/model_zoo/official/nlp/pangu_alpha/train.py +++ b/model_zoo/official/nlp/pangu_alpha/train.py @@ -106,6 +106,7 @@ def run_train(args_opt): pipeline_stage=args_opt.stage_num, micro_batch_num=args_opt.micro_size, optimizer_shard=bool(args_opt.optimizer_shard), + vocab_emb_dp=bool(args_opt.word_emb_dp), recompute=True) config = PanguAlphaConfig(batch_size=batch_size, num_heads=args_opt.num_heads, hidden_size=args_opt.embedding_size, seq_length=args_opt.seq_length, @@ -221,6 +222,7 @@ def run_train_pipeline(args_opt): pipeline_stage=args_opt.stage_num, micro_batch_num=args_opt.micro_size, optimizer_shard=bool(args_opt.optimizer_shard), + vocab_emb_dp=bool(args_opt.word_emb_dp), recompute=True) config = PanguAlphaConfig(batch_size=batch_size // parallel_config.micro_batch_num, num_heads=args_opt.num_heads, hidden_size=args_opt.embedding_size,