diff --git a/model_zoo/official/nlp/pangu_alpha/predict.py b/model_zoo/official/nlp/pangu_alpha/predict.py
index 7dd0972532c..d7147470ba3 100644
--- a/model_zoo/official/nlp/pangu_alpha/predict.py
+++ b/model_zoo/official/nlp/pangu_alpha/predict.py
@@ -43,11 +43,13 @@ def run_predict(args_opt):
     device_id = int(os.getenv('DEVICE_ID'))
     local_rank = rank_id
     print('local_rank:{}, device id:{} start to run...'.format(local_rank, device_id), flush=True)
+    # Set execution mode
     context.set_context(save_graphs=False,
                         mode=context.GRAPH_MODE,
                         device_target="Ascend",
                         device_id=device_id)
     context.set_context(variable_memory_max_size="30GB")
+    # Set parallel context
     if args_opt.distribute == "true":
         D.init()
         device_num = D.get_group_size()
@@ -71,6 +73,7 @@ def run_predict(args_opt):
         rank = 0
         device_num = 1

+    # Set model property
     model_parallel_num = args_opt.tensor_model_parallel_num
     data_parallel_num = int(device_num / model_parallel_num)
     per_batch_size = args_opt.per_batch_size
@@ -99,10 +102,13 @@ def run_predict(args_opt):
     print("=====args_opt is: ", args_opt, flush=True)
     ckpt_name = args_opt.load_ckpt_name
+    # Define network
     pangu_alpha = PanguAlpha(config)
     eval_net = EvalNet(pangu_alpha)
     eval_net.set_train(False)
     model_predict = Model(eval_net)
+
+    # Compile network and obtain tensor layout for loading ckpt
     inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
     predict_layout = model_predict.infer_predict_layout(inputs_np)
     print("======start load_distributed checkpoint", flush=True)
@@ -111,19 +117,24 @@ def run_predict(args_opt):
     ckpt_file_list = [os.path.join(args_opt.load_ckpt_path, f"{ckpt_name}_{ckpt_rank}.ckpt") for ckpt_rank in range(0, 512)]
     print(f"Loading from path {ckpt_file_list[0]}", flush=True)
+    # Load checkpoint files
     load_distributed_checkpoint(eval_net, ckpt_file_list, predict_layout)
     print("================load param ok=================", flush=True)
     from src.tokenization_jieba import JIEBATokenizer
     from src.generate import generate
+    # Define tokenizer
     tokenizer = JIEBATokenizer(os.path.join(args_opt.tokenizer_path, 'vocab10.vocab'),
                                os.path.join(args_opt.tokenizer_path, 'vocab10.model'))
+    # Tokenize input sentence to ids
     sample = "今天是一个好天气"
     tokenized_token = tokenizer.tokenize(sample)
     start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token)
     input_ids = np.array(start_sentence).reshape(1, -1)
+    # Call inference
     output_ids = generate(model_predict, input_ids, config.seq_length, 9)
+    # Decode output ids to sentence
     output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist())
     print('Output is:', output_samples, flush=True)
diff --git a/model_zoo/official/nlp/pangu_alpha/src/dataset.py b/model_zoo/official/nlp/pangu_alpha/src/dataset.py
index 28d281e2d85..ce2540b01d0 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/dataset.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/dataset.py
@@ -40,16 +40,22 @@ def get_input_data(input_ids, eod_id, rank, dis):
     input_ids = input_ids[rank*dis: (rank+1)*dis]
     seq_length = input_ids.shape[1] - 1
+    # Initialize position_ids and attention_mask
     batch_input_ids = input_ids
     batch_position_ids = np.ones((dis, seq_length))
     batch_attention_mask = np.ones((dis, seq_length, seq_length))
+
+    # Loop through each sample in the batch
     for bs_i, _ in enumerate(range(len(input_ids))):
+        # Get normal position_ids and attention_mask
        local_ids = input_ids[bs_i]
        batch_attention_mask[bs_i] = np.tril(np.ones(shape=(seq_length, seq_length)))
        batch_position_ids[bs_i] = np.arange(seq_length)
+        # Find the index of end-of-document (EOD) tokens
        eod_index = batch_position_ids[bs_i, local_ids[:-1] == eod_id].astype(np.int32)
        prev_index = 0
        for i in range(eod_index.size):
+            # Reset position_ids and attention_mask considering EOD
            index = eod_index[i]
            batch_attention_mask[bs_i, (index+1):, :(index+1)] = 0
            batch_position_ids[bs_i, (index+1):] -= (index + 1 - prev_index)
@@ -76,6 +82,8 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_
        dataset_restore: the dataset for training or evaluating
    """
    ds.config.set_seed(1)
+
+    # Get path for source data files
    home_path = os.path.join(os.getcwd(), data_path)
    files = os.listdir(data_path)
    dis = int(batch_size / device_num)
@@ -89,9 +97,12 @@ def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_
        if not name.endswith(".db")
    ]

+    # Load data files and preprocess
    dataset = ds.MindDataset(data[data_start_index:], columns_list=[column_name], shuffle=False)
    type_cast_op = C.TypeCast(mstype.int32)
    type_cast_op_float = C.TypeCast(mstype.float16)
+
+    # If eod_reset is enabled, two additional inputs (position ids and attention mask) will be generated from input_ids
    if eod_reset:
        map_func = (lambda input_ids: get_input_data(input_ids, eod_id, rank, dis))
        dataset = dataset.batch(batch_size, drop_remainder=drop)
diff --git a/model_zoo/official/nlp/pangu_alpha/src/generate.py b/model_zoo/official/nlp/pangu_alpha/src/generate.py
index 248a7b17a76..fc69156de40 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/generate.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/generate.py
@@ -37,22 +37,30 @@ def generate(model, origin_inputs, seq_length, end_token=50256):
    seq_length = seq_length
    _, valid_length = origin_inputs.shape
    pad_length = seq_length - origin_inputs.shape[-1]
+    # Pad original inputs to seq_length
    input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0))
    print("input_ids is ", input_ids)
+
+    # Each iteration generates one token; loop until the target seq_length is reached or the end token is generated
    while valid_length < seq_length:
        inputs = Tensor(input_ids, mstype.int32)
+        # Run a single inference step
        probs, p_args = model.predict(inputs)
+        # Get the top-k values and indices for the current position from the padded outputs
        probs = probs.asnumpy()[valid_length-1, :]
        p_args = p_args.asnumpy()[valid_length-1, :]
-
+        # Randomly sample a token from the top-k distribution as the output for this round
        p = probs
        p = p / sum(p)
        target_index = np.random.choice(len(p), p=p)
+        # Stop when the end token is generated or the maximum length is reached
        if p_args[target_index] == end_token or valid_length == seq_length-1:
            outputs = input_ids
            break
+        # Append the newly generated token to input_ids
        input_ids[0][valid_length] = p_args[target_index]
        valid_length += 1
+    # Return only the valid part of the padded outputs
    length = np.sum(outputs != 0)
    outputs = outputs[0][:length]
    return outputs
diff --git a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py
index 5527d3ee888..5d8713c620d 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py
@@ -203,7 +203,9 @@ class Output(nn.Cell):
        super(Output, self).__init__()
        input_size = config.embedding_size
        output_size = config.embedding_size * config.expand_ratio
+        # Project to expand_ratio*embedding_size
        self.mapping = Mapping_output(config, input_size, output_size)
+        # Project back to embedding_size
        self.projection = Mapping(config, output_size, input_size, scale)
        self.activation = nn.GELU()
        self.activation.gelu.shard(((config.dp, 1, config.mp),))
@@ -212,8 +214,10 @@
        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))

    def construct(self, x):
+        # [bs, seq_length, expand_ratio*embedding_size]
        hidden = self.activation(self.mapping(x))
        output = self.projection(hidden)
+        # [bs, seq_length, embedding_size]
        output = self.dropout(output)
        return output

@@ -235,6 +239,7 @@ class AttentionMask(nn.Cell):
            ((config.dp, 1, 1), (config.dp, 1, 1)))  # yzz: use 64, 1, 1?
        self.expand_dim = P.ExpandDims().shard(((1, 1),))
        ones = np.ones(shape=(config.seq_length, config.seq_length))
+        # Default lower-triangular mask matrix
        self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
        self.multiply = P.Mul().shard(((config.dp, 1, 1), (1, 1, 1)))

@@ -245,12 +250,14 @@
        input_shape = P.Shape()(input_mask)
        shape_right = (input_shape[0], 1, input_shape[1])
        shape_left = input_shape + (1,)
+        # Mask the padded inputs
        mask_left = self.reshape(input_mask, shape_left)
        mask_right = self.reshape(input_mask, shape_right)
        attention_mask = self.mul(mask_left, mask_right)
        lower_traiangle = self.expand_dim(self.lower_triangle_mask, 0)
+        # [bs, seq_length, seq_length]
        attention_mask = self.multiply(
-            attention_mask, lower_traiangle) #bs seq_length seq_length
+            attention_mask, lower_traiangle)
        return attention_mask

@@ -305,7 +312,9 @@
    """
    def __init__(self, config, scale=1.0, layer_idx=None):
        super(Attention, self).__init__()
+        # Attention mask matrix
        self.get_attention_mask = AttentionMask(config)
+        # Output layer
        self.projection = Mapping(config, config.embedding_size,
                                  config.embedding_size, scale)
        self.transpose = P.Transpose().shard(((config.dp, 1, config.mp, 1),))
@@ -313,6 +322,7 @@
            ((config.dp, config.mp, 1, 1),))
        self.reshape = P.Reshape()
        self.n_head = config.num_heads
+        # Embedding size per attention head
        self.size_per_head = config.embedding_size // self.n_head
        self.concat_k = P.Concat(axis=3)
        self.concat_v = P.Concat(axis=2)
@@ -329,6 +339,7 @@
            ((config.dp, 1, 1, 1), (1,)))
        self.add = P.TensorAdd().shard(
            ((config.dp, 1, 1, 1), (config.dp, config.mp, 1, 1)))
+        # Normalization factor for attention scores, sqrt(dk) as widely used
        if self.scale:
            self.scale_factor = Tensor(math.sqrt(self.size_per_head))
        if layer_idx is not None:
@@ -347,16 +358,19 @@
        self.softmax.softmax.shard(((config.dp, config.mp, 1),))
        self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))

+        # Query
        self.dense1 = nn.Dense(config.embedding_size,
                               config.embedding_size).to_float(
                                   config.compute_dtype)
        self.dense1.matmul.shard(((config.dp, 1), (config.mp, 1)))
        self.dense1.bias_add.shard(((config.dp, config.mp), (config.mp,)))
+        # Key
        self.dense2 = nn.Dense(config.embedding_size,
                               config.embedding_size).to_float(
                                   config.compute_dtype)
        self.dense2.matmul.shard(((config.dp, 1), (config.mp, 1)))
        self.dense2.bias_add.shard(((config.dp, config.mp), (config.mp,)))
+        # Value
        self.dense3 = nn.Dense(config.embedding_size,
                               config.embedding_size).to_float(
                                   config.compute_dtype)
@@ -380,18 +394,22 @@
        original_shape = F.shape(x)
        x = F.reshape(x, (-1, original_shape[-1]))
+        # Self-attention: query, key and value are derived from the same inputs
        query = self.dense1(x)
        key = self.dense2(x)
        value = self.dense3(x)
+        # [bs, num_heads, seq_length, size_per_head]
        query = self.transpose(
            F.reshape(
                query,
                (-1, original_shape[1], self.n_head, self.size_per_head)),
            (0, 2, 1, 3))
+        # [bs, num_heads, size_per_head, seq_length]
        key = self.transpose(
            F.reshape(
                key, (-1, original_shape[1], self.n_head, self.size_per_head)),
            (0, 2, 3, 1))
+        # [bs, num_heads, seq_length, size_per_head]
        value = self.transpose(
            F.reshape(
                value,
@@ -403,8 +421,11 @@
            key = self.concat_k((past_key, key))
            value = self.concat_v(past_value, value)
        layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value])
+        # Self-attention considering attention mask
        attention = self._attn(query, key, value, attention_mask)
+        # [bs, seq_length, embedding_size]
        attention_merge = self.merge_heads(attention)
+        # Output projection
        output = self.projection(attention_merge)
        output = self.dropout(output)
        return output, layer_present
@@ -454,11 +475,14 @@
        Returns:
            weighted_values: Tensor, the weighted sum scores
        """
+        # Normalize query and key before MatMul, off by default
        if not self.scale:
            query = query / F.cast(self.coeff, F.dtype(query))
            key = key / F.cast(self.coeff, F.dtype(key))

+        # Attention score [bs, num_heads, seq_length, seq_length]
        score = self.batch_matmul(query, key)
+        # Normalize the score after the query-key MatMul, on by default
        if self.scale:
            score = self.real_div(
                score,
@@ -466,6 +490,7 @@
        ori_dtype = P.DType()(score)
        score = P.Cast()(score, mstype.float32)
+        # Subtract 10000 at masked positions to exclude them from the softmax
        multiplu_out = self.sub(
            P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)),
            P.Cast()(attention_mask, P.DType()(score)))
@@ -474,13 +499,15 @@
        attention_scores = self.add(adder, score)

        shape = F.shape(attention_scores)
+        # Attention probabilities
        attention_probs = self.softmax(
            F.reshape(attention_scores,
-                      (shape[0], -1, shape[-1])))  # yzz modify
+                      (shape[0], -1, shape[-1])))
        attention_probs = P.Cast()(attention_probs, ori_dtype)
        attention_probs = F.reshape(attention_probs, shape)
        attention_probs = self.prob_dropout(attention_probs)
+        # Weighted sum output [bs, num_heads, seq_length, size_per_head]
        weighted_values = self.batch_matmul(attention_probs, value)
        return weighted_values

@@ -517,11 +544,13 @@ class Block(nn.Cell):
        self.attention = Attention(config, scale, layer_idx)
        self.layernorm2.gamma.parallel_optimizer = False
        self.layernorm2.beta.parallel_optimizer = False
+        # Feed Forward Network, FFN
        self.output = Output(config, scale)
        self.post_layernorm_residual = config.post_layernorm_residual
        self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
        self.last_add = P.TensorAdd().shard(
            ((config.dp, 1, 1), (config.dp, 1, 1)))
+        # The last activation of this layer is excluded from recomputation and saved for the backward process
        self.last_add.recompute(False)
        self.dtype = config.compute_dtype

@@ -529,12 +558,15 @@
        r"""
        The forward process of the block.
""" + # [bs, seq_length, embedding_size] input_x = self.layernorm1(x) input_x = F.cast(input_x, self.dtype) attention, layer_present = self.attention(input_x, input_mask, layer_past) + # For post-layernorm the inputs for residual path are output of self-attention and output of layernorm if self.post_layernorm_residual: x = self.add(input_x, attention) + # For pre-layernorm the inputs for residual path are output of self-attention and input of this layer else: x = self.add(x, attention) @@ -556,6 +588,7 @@ class QueryLayerAttention(Attention): original_shape = F.shape(x) x = F.reshape(x, (-1, original_shape[-1])) query_hidden_state = F.reshape(query_hidden_state, (-1, original_shape[-1])) + # For query_layer_attention, query are derived from outputs of previous layer and key, value are derived from an added parameter query_embedding query = self.dense1(query_hidden_state) key = self.dense2(x) value = self.dense3(x) @@ -611,7 +644,8 @@ class QueryLayer(nn.Cell): def construct(self, x, query_hidden_state, input_mask, layer_past=None): r""" - Query Layer. + Query Layer shares a similar structure with normal layer block + except that it is not a traditional self-attention. """ input_x = self.layernorm1(x) input_x = F.cast(input_x, self.dtype) @@ -650,6 +684,7 @@ class PanguAlpha_Model(nn.Cell): def __init__(self, config): super(PanguAlpha_Model, self).__init__() self.get_attention_mask = AttentionMask(config) + # Word embedding self.word_embedding = EmbeddingLookup(config).set_comm_fusion(1) if config.load_ckpt_path: # Loading the embedding table from the ckpt path: @@ -662,6 +697,7 @@ class PanguAlpha_Model(nn.Cell): else: position_table_param = TruncatedNormal(0.02) + # Position embedding self.position_embedding = nn.Embedding( config.seq_length, config.embedding_size, @@ -671,11 +707,13 @@ class PanguAlpha_Model(nn.Cell): self.position_embedding.gather.shard(((1, 1), (config.dp,))) self.position_embedding.expand.shard(((config.dp, 1),)) self.blocks = nn.CellList() + # Total fusion groups for HCCL operators. Specifically, the same tyep HCCL operators in same group will be fused. fusion_group_num = 4 fusion_group_size = config.num_layers // fusion_group_num fusion_group_size = max(fusion_group_size, 1) num_layers = config.num_layers + # If top_query_attention enabled, replace the last normal self-attention layers with this top_query_attention layer if config.use_top_query_attention: num_layers -= 1 self.num_layers = num_layers @@ -683,7 +721,10 @@ class PanguAlpha_Model(nn.Cell): for i in range(num_layers): per_block = Block(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2) + # Each layer will be remoputed in the backward process. The output activation of each layer will be saved, + # in other words, in backward process each block will be almosttotally recomputed. per_block.recompute() + # Dropout will not be recomputed to ensure the consistency between forward and the corresponding backward. 
            per_block.attention.dropout.dropout_gen_mask.recompute(False)
            per_block.attention.prob_dropout.dropout_gen_mask.recompute(False)
            per_block.output.dropout.dropout_gen_mask.recompute(False)
@@ -709,6 +750,9 @@
        self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
        self.eod_reset = config.eod_reset
+        # If top_query_attention is enabled, input_position (the position ids) is used as the index
+        # into a query embedding table to obtain the top query hidden states. Together with the outputs
+        # of the normal self-attention layers, they feed an additional attention layer attached to the output of the model
        if config.use_top_query_attention:
            if config.load_ckpt_path:
                # Loading the embedding table from the ckpt path:
@@ -745,19 +789,24 @@
        if not self.use_past:
            layer_past = self.past

+        # Word embedding
        input_embedding, embedding_table = self.word_embedding(input_ids)
+        # If eod_reset is disabled, there is only one input from the dataset, i.e., input_ids,
+        # and the corresponding input_position and attention_mask are derived from it.
        if not self.eod_reset:
            batch_size, seq_length = F.shape(input_ids)
            input_position = F.tuple_to_array(F.make_range(seq_length))
            input_position = P.Tile()(input_position, (batch_size, 1))
            attention_mask = self.get_attention_mask(input_mask)
        position_embedding = self.position_embedding(input_position)
+        # Input features [bs, seq_length, embedding_size]
        hidden_states = self.add(input_embedding, position_embedding)
        hidden_states = self.dropout(hidden_states)
        hidden_states = P.Cast()(hidden_states, mstype.float16)
        attention_mask = self.expand_dims(attention_mask, 1)
        present_layer = ()
+        # Loop through each self-attention layer
        for i in range(self.num_layers):
            hidden_states, present = self.blocks[i](hidden_states,
                                                    attention_mask, layer_past)
@@ -766,12 +815,12 @@
        output_state = self.layernorm(hidden_states)
        output_state = F.cast(output_state, self.dtype)

+        # Top query attention layer
        if self.use_top_query_attention:
            top_query_hidden_states = self.top_query_embedding(input_position)
            output_state, present = self.top_query_layer(output_state, top_query_hidden_states,
                                                         attention_mask, layer_past)
            present_layer = present_layer + (present,)
-
        return output_state, present_layer, embedding_table

@@ -799,6 +848,7 @@ class PanguAlpha_Head(nn.Cell):
    def construct(self, state, embedding_table):
        state = P.Reshape()(state, (-1, self.embedding_size))
+        # Output logits over the vocabulary [bs*seq_length, vocab_size]
        logits = self.matmul(state, self.cast(embedding_table, self.dtype))
        return logits

@@ -817,7 +867,9 @@ class PanguAlpha(nn.Cell):
    """
    def __init__(self, config):
        super(PanguAlpha, self).__init__()
+        # Network backbone of PanguAlpha
        self.backbone = PanguAlpha_Model(config)
+        # Network head to get logits over the vocabulary
        self.head = PanguAlpha_Head(config)

    def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, past=None):
@@ -844,6 +896,7 @@ class CrossEntropyLoss(nn.Cell):
        self.mean = P.ReduceMean()
        self.sum = P.ReduceSum().shard(((config.dp, config.mp),))
        self.onehot = P.OneHot().shard(((config.dp, config.mp), (), ()))
+        # On/off values for one-hot encoding; for label smoothing, modify the off_value
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.vocab_size = config.vocab_size
@@ -868,7 +921,9 @@
        r"""
        Compute loss using logits, label and input mask
        """
+        # [bs*seq_length, vocab_size]
        logits = F.cast(logits, mstype.float32)
+        # LogSoftmax of the logits over the last dimension
        _, logit_max = self.max(logits)
        logit_sub = self.sub(logits, logit_max)
        logit_exp = self.exp(logit_sub)
@@ -876,12 +931,17 @@
        exp_sum = P.Reshape()(exp_sum, (F.shape(exp_sum)[0], 1))
        softmax_result = self.div(logit_exp, exp_sum)
        log_softmax_result = self.log(self.add(softmax_result, self.eps_const))
+
+        # Flatten label to [bs*seq_length]
        label = P.Reshape()(label, (-1,))
+        # Get one-hot labels [bs*seq_length, vocab_size]
        one_hot_label = self.onehot(label, self.vocab_size, self.on_value,
                                    self.off_value)
+        # Cross-Entropy loss
        loss = self.mul(log_softmax_result, one_hot_label)
        loss_unsum = self.neg(loss)
        loss_reduce = self.sum(loss_unsum, -1)
+        # input_mask indicates whether the inputs are padded; padded positions are not counted in the loss
        input_mask = P.Reshape()(input_mask, (-1,))
        numerator = self.sum2(self.mul2(loss_reduce, input_mask))

@@ -909,6 +969,7 @@ class PanguAlphaWithLoss(nn.Cell):
        super(PanguAlphaWithLoss, self).__init__(auto_prefix=False)
        self.network = network
        self.loss = loss
+        # Id of the end_of_sentence token, 6 in the vocabulary
        self.eos_token = eos_token
        self.slice = P.StridedSlice().shard(((config.dp, 1),))
        self.not_equal = P.NotEqual().shard(((config.dp, 1), ()))
@@ -922,6 +983,10 @@
        r"""
        PanguAlphaWithLoss
        """
+        # input_ids [bs, seq_length+1]
+        # input_position [bs, seq_length], only available when eod_reset is enabled
+        # attention_mask [bs, seq_length, seq_length], only available when eod_reset is enabled
+        # Get input tokens [bs, seq_length]
        tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1))

        if self.eod_reset:
@@ -929,12 +994,14 @@
            attention_mask = self.slice_mask(attention_mask, (0, 0, 0),
                                             (self.batch_size, self.len, self.len),
                                             (1, 1, 1))
-
+        # Check whether there is padding in the inputs
        input_mask = F.cast(self.not_equal(tokens, self.eos_token),
                            mstype.float32)
        logits = self.network(tokens, input_mask, input_position, attention_mask)
+        # Get labels corresponding to the input tokens
        labels = self.slice(input_ids, (0, 1), (self.batch_size, self.len + 1),
                            (1, 1))
+        # Loss
        output = self.loss(logits, labels, input_mask)
        return output
diff --git a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_config.py b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_config.py
index 1da5a6c646a..c45f9128cbc 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_config.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_config.py
@@ -50,13 +50,17 @@ class PANGUALPHAConfig:
        self.embedding_size = embedding_size
        self.num_layers = num_layers
        self.num_heads = num_heads
+        # The expansion ratio of the feature size in the FFN
        self.expand_ratio = expand_ratio
+        # Use post-layernorm or pre-layernorm, default: pre-layernorm
        self.post_layernorm_residual = post_layernorm_residual
        self.dropout_rate = dropout_rate
        self.compute_dtype = compute_dtype
+        # Whether to use incremental inference
        self.use_past = use_past
        self.dp = data_parallel_num
        self.mp = model_parallel_num
+        # Whether to use the self-implemented layernorm
        self.self_layernorm = self_layernorm
        self.stage_num = stage_num
        self.micro_size = micro_size
diff --git a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py
index 0265aa69f39..4f34123552b 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha_wrapcell.py
@@ -45,6 +45,7 @@ def _clip_grad(clip_type, clip_value, grad):
    if clip_type not in [0, 1]:
        return grad
    dt = F.dtype(grad)
+    # 0 for clip_by_value and 1 for clip_by_norm
    if clip_type == 0:
        new_grad = C.clip_by_value(
            grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
@@ -107,12 +108,14 @@
    def construct(self, input_ids, input_position=None, attention_mask=None, layer_past=None, sens=None):
        """Defines the computation performed."""
        weights = self.weights
+        # Forward process
        loss = self.network(input_ids, input_position, attention_mask)
        scaling_sens = self.scale_sense
        # alloc status and clear should be right before gradoperation
        status, scaling_sens = self.start_overflow_check(loss, scaling_sens)
        scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss))
+        # Backward process using loss scale
        grads = self.grad(self.network,
                          weights)(input_ids,
                                   input_position,
                                   attention_mask,
@@ -129,8 +132,11 @@
        grads = self.hyper_map(
            F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
            grads)
+        # Check for overflow
        cond = self.get_overflow_status(status, grads)
        overflow = self.process_loss_scale(cond)
+        # If overflow occurs, skip the weight update;
+        # if not, update the weights
        if overflow:
            succ = False
        else:
diff --git a/model_zoo/official/nlp/pangu_alpha/src/utils.py b/model_zoo/official/nlp/pangu_alpha/src/utils.py
index 4074b98539b..c9c4e017778 100644
--- a/model_zoo/official/nlp/pangu_alpha/src/utils.py
+++ b/model_zoo/official/nlp/pangu_alpha/src/utils.py
@@ -70,7 +70,9 @@ class GlobalNorm(nn.Cell):
            self.values.append(Tensor([self.group_size*1.0], mstype.float32))
        self.values = tuple(self.values)

    def construct(self, grads):
+        # Squared sum of gradients on the current rank
        square_sum_dp = self.hyper_map(get_square_sum, grads, self.values)
+        # Global norm from the all-reduced squared sums of gradients
        global_norms = F.sqrt(P.AllReduce()(F.addn(square_sum_dp)))
        return global_norms
diff --git a/model_zoo/official/nlp/pangu_alpha/train.py b/model_zoo/official/nlp/pangu_alpha/train.py
index a3a26e76404..991a2555a26 100644
--- a/model_zoo/official/nlp/pangu_alpha/train.py
+++ b/model_zoo/official/nlp/pangu_alpha/train.py
@@ -77,10 +77,12 @@ def run_train(args_opt):
    The main training process.
""" device_id = int(os.getenv('DEVICE_ID')) + # Set execution mode context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id) context.set_context(variable_memory_max_size="30GB") + # Set parallel context if args_opt.distribute == "true": D.init() device_num = D.get_group_size() @@ -102,6 +104,8 @@ def run_train(args_opt): else: rank = 0 device_num = 1 + + # Set model property model_parallel_num = args_opt.tensor_model_parallel_num data_parallel_num = int(device_num / model_parallel_num) batch_size = args_opt.per_batch_size * device_num @@ -124,18 +128,23 @@ def run_train(args_opt): eod_reset=bool(args_opt.eod_reset), word_emb_dp=True) print("===config is: ", config, flush=True) + + # Define network pangu_alpha = PanguAlpha(config) loss = CrossEntropyLoss(config) pangu_alpha_with_loss = PanguAlphaWithLoss(config, pangu_alpha, loss) pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss) print("=====args_opt is: ", args_opt, flush=True) + + # Warm-up and cosine decay learning rate lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr, warmup_steps=args_opt.warmup_step, decay_steps=200000, lr_scale=1) + # Set weight decay coefficient, zero for bias and layernorm, 1e-1 for rest decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower() params = pangu_alpha.trainable_params() decay_params = list(filter(decay_filter, params)) @@ -153,8 +162,10 @@ def run_train(args_opt): optimizer = nn.Lamb(group_params, learning_rate=lr) else: optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95) + # Initial scaling sens loss_scale_value = math.pow(2, 32) epoch_num = args_opt.epoch_size + # Dataset loading mindrecord files ds = create_dataset(config.batch_size, data_path=args_opt.data_url, data_start_index=0, eod_reset=config.eod_reset, eod_id=args_opt.eod_id, device_num=device_num, rank=rank, epoch=epoch_num)