diff --git a/model_zoo/official/nlp/pangu_alpha/predict.py b/model_zoo/official/nlp/pangu_alpha/predict.py index a27aa87b239..15acea75a29 100644 --- a/model_zoo/official/nlp/pangu_alpha/predict.py +++ b/model_zoo/official/nlp/pangu_alpha/predict.py @@ -111,7 +111,16 @@ def run_predict(args_opt): # Compile network and obtain tensor layout for loading ckpt inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32) - predict_layout = model_predict.infer_predict_layout(inputs_np) + current_index = Tensor(np.array([0]), mstype.int32) + + if config.use_past: + batch_valid_length = Tensor(np.array([0]), mstype.int32) + init_true = Tensor([True], mstype.bool_) + inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32) + predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length) + _ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_true, batch_valid_length) + else: + predict_layout = model_predict.infer_predict_layout(inputs_np, current_index) print("======start load_distributed checkpoint", flush=True) # For 2.6B and 13B models, the number of ckpt files is 512. ckpt_name = 'filerted' @@ -123,7 +132,7 @@ def run_predict(args_opt): print("================load param ok=================", flush=True) from src.tokenization_jieba import JIEBATokenizer - from src.generate import generate + from src.generate import generate, generate_increment # Define tokenizer tokenizer = JIEBATokenizer(os.path.join(args_opt.tokenizer_path, 'vocab10.vocab'), os.path.join(args_opt.tokenizer_path, 'vocab10.model')) @@ -134,7 +143,8 @@ def run_predict(args_opt): start_sentence = tokenizer.convert_tokens_to_ids(tokenized_token) input_ids = np.array(start_sentence).reshape(1, -1) # Call inference - output_ids = generate(model_predict, input_ids, opt) + generate_func = generate_increment if config.use_past else generate + output_ids = generate_func(model_predict, input_ids, opt) # Decode output ids to sentence output_samples = tokenizer.convert_ids_to_tokens(output_ids.tolist()) print('Output is:', output_samples, flush=True) diff --git a/model_zoo/official/nlp/pangu_alpha/src/generate.py b/model_zoo/official/nlp/pangu_alpha/src/generate.py index a729dc9b23e..a9cefb87174 100644 --- a/model_zoo/official/nlp/pangu_alpha/src/generate.py +++ b/model_zoo/official/nlp/pangu_alpha/src/generate.py @@ -24,7 +24,7 @@ from mindspore.ops import operations as P def generate(model, origin_inputs, config): """ - TopK for text generation + Text generation Inputs: model: the model for inferencing @@ -59,7 +59,7 @@ def generate(model, origin_inputs, config): while valid_length < target_length: inputs = Tensor(input_ids, mstype.int32) # Indicate the exact token position - current_index = valid_length - 1 if valid_lenght - 1 > 0 else 0 + current_index = valid_length - 1 if valid_length - 1 > 0 else 0 current_index = Tensor([current_index], mstype.int32) # Call a single inference log_probs = model.predict(inputs, current_index) @@ -103,6 +103,10 @@ def generate(model, origin_inputs, config): if p_args[target_index] == end_token or valid_length == target_length-1: outputs = input_ids break + + # update frequency list + target = p_args[target_index] + frequency_list[0][target] = frequency_list[0][target] + 1 # Modify input_ids with newly generated token input_ids[0][valid_length] = p_args[target_index] valid_length += 1 @@ -110,3 +114,117 @@ def generate(model, origin_inputs, config): length = np.sum(outputs != 0) outputs = 
outputs[0][:length] return outputs + +def generate_increment(model, origin_inputs, config): + """ + Text generation for incremental inference + + Inputs: + model: the model for inferencing + origin_inputs: the original inputs based on which the model will continue writing + config: inference configurations + + Returns: + outputs: the ids for the generated text + """ + # Get configurations for inference + frequency_penalty = config.frequency_penalty + presence_penalty = config.presence_penalty + top_p = config.top_p + top_k_num = config.top_k_num + max_generate_length = config.max_generate_length + seq_length = config.seq_length + end_token = config.end_token + + _, valid_length = origin_inputs.shape + # Init outputs with original inputs + outputs = [origin_inputs[0][i] for i in range(valid_length)] + # If target length exceeds seq_length, use seq_length instead + target_length = valid_length + max_generate_length + target_length = seq_length if target_length > seq_length else target_length + + # A list of the frequency of each token + frequency_list = np.array([[0 for _ in range(config.vocab_size)]]) + pad_length = seq_length - origin_inputs.shape[-1] + # Pad original inputs to seq_length + input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)), 'constant', constant_values=(0, 0)) + print("input_ids is ", input_ids) + + # Indicate the exact token position + current_index = valid_length - 1 if valid_length - 1 > 0 else 0 + batch_valid_length = Tensor(np.array([current_index]), mstype.int32) + current_index = Tensor(np.array([current_index]), mstype.int32) + # For the first graph, init should be false so that the saved key and value states are reset + init_true = Tensor([True], mstype.bool_) + init_false = Tensor([False], mstype.bool_) + init = init_false + # Declare the first graph + model.predict_network.add_flags_recursive(is_first_iteration=True) + # Call a single inference with input size of (bs, seq_length) + logits = model.predict(Tensor(input_ids, mstype.int32), current_index, init, batch_valid_length) + + # Declare the second graph and set init to true + init = init_true + model.predict_network.add_flags_recursive(is_first_iteration=False) + + # A single loop generates one token; loop until reaching the target length or generating the eod token + while valid_length < target_length: + # Reshape the output logits + logits = logits.asnumpy() + log_probs = logits.reshape(1, config.vocab_size) + + # Get the revised log_probs considering the frequency and presence penalties to reduce duplicates in the generated results + log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty + + # Convert the log_probs to probability + logits = P.Pow()(10, Tensor(log_probs_revised, mstype.float32)) + + # If top_p is less than 1.0, use top_p sampling + if top_p < 1.0: + # Only consider the 5000 largest logits to reduce computation + sorted_logits, index = P.TopK(sorted=True)(logits, 5000) + cumsum_logits = P.CumSum()(sorted_logits, 1) + cumsum_logits = cumsum_logits.asnumpy()[0] + index = index.asnumpy()[0] + sorted_logits = sorted_logits.asnumpy()[0] + top_p_num = sum(cumsum_logits > top_p) + # In case the distribution is flat and the sum of the 5000 largest probabilities is not large enough + if top_p_num == 0: + top_p_num = 5000 + # Get the corresponding probs and indices + probs = sorted_logits[:top_p_num] + p_args = index[:top_p_num] + p = probs / sum(probs) + # If top_p is set to 1.0, use top_k sampling + else: + # Get the corresponding probs and
indices + probs, p_args = P.TopK(sorted=True)(logits, top_k_num) + probs = probs.asnumpy()[0] + p_args = p_args.asnumpy()[0] + # Avoid rounding errors that leave all probabilities zero + if sum(probs) == 0: + probs = np.array([1 / top_k_num for _ in range(top_k_num)]) + p = probs / sum(probs) + + # Randomly select a token as the output for this round + target_index = np.random.choice(len(p), p=p) + # Stop condition: the eod token is generated or the target length is reached + if p_args[target_index] == end_token or valid_length == target_length-1: + break + + # Update frequency list + target = p_args[target_index] + frequency_list[0][target] = frequency_list[0][target] + 1 + valid_length += 1 + + batch_valid_length = Tensor(np.array([valid_length - 1]), mstype.int32) + current_index = Tensor(np.array([0]), mstype.int32) + input_id = Tensor([[target]], mstype.int32) + # Update outputs with the current generated token + outputs.append(int(target)) + + # Call a single inference with input size of (bs, 1) + logits = model.predict(input_id, current_index, init, batch_valid_length) + # Return the generated token ids + return np.array(outputs) diff --git a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py index f19102b1254..e1cb9cc3aa2 100644 --- a/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py +++ b/model_zoo/official/nlp/pangu_alpha/src/pangu_alpha.py @@ -420,7 +420,28 @@ class Attention(nn.Cell): self.dense3.matmul.shard(((config.dp, 1), (config.mp, 1))) self.dense3.bias_add.shard(((config.dp, config.mp), (config.mp,))) - def construct(self, x, attention_mask, layer_past=None): + self.is_first_iteration = True + self.dtype = config.compute_dtype + self.use_past = config.use_past + if self.use_past: + # operators used for state reuse + seq_range = np.arange(config.seq_length).reshape(1, 1, -1) + self.range = Tensor(np.tile(seq_range, (config.batch_size, 1, 1)), mstype.int32) + self.seq_length = config.seq_length + self.attention_mask = Tensor(np.tril(np.ones(shape=(self.seq_length, self.seq_length))), mstype.int32) + self.slice = P.StridedSlice().shard(((1, 1, 1, 1),)) + self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ())) + self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),)) + self.expand_dims = P.ExpandDims().shard(((1, 1, 1),)) + self.tensor_le = P.LessEqual().shard(((1, 1, 1), (1, 1, 1))) + self.add = P.TensorAdd().shard(((1, 1, 1, 1), (1, 1, 1, 1))) + self.equal = P.Equal().shard(((1, 1, 1), (1, 1, 1))) + self.sub1 = P.Sub().shard(((1,), ())) + self.tile = P.Tile().shard(((1, 1, 1, 1),)) + self.less = P.Less().shard(((1, 1, 1), (1, 1, 1))) + self.mul1 = P.Mul().shard(((1, 1, 1, 1), (1, 1, 1, 1))) + + def construct(self, x, attention_mask, key_past=None, value_past=None, batch_valid_length=None): """ self-attention @@ -428,7 +449,9 @@ x: output of previous layer attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length) - layer_past: the previous feature map + key_past: the previously saved key state + value_past: the previously saved value state + batch_valid_length: the valid input seq_length without padding Returns: output: Tensor, the output logit of this layer layer_present: Tensor, the feature map of current layer @@ -458,12 +481,44 @@ value, (-1, original_shape[1], self.n_head, self.size_per_head)), (0, 2, 1, 3)) + + # key and value for current token(s) + key_present = key + value_present = value if self.use_past: - past_value = layer_past[1] - past_key = self.transpose(layer_past[0], (0, 1, 3, 2)) - key = self.concat_k((past_key, key)) - value = self.concat_v(past_value, value) -
layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value]) + # The first graph with the input size of (bs, seq_length) + if self.is_first_iteration: + # Get the valid input length without padding + valid_length_vector = F.cast(self.less(self.range, batch_valid_length), self.dtype) + # Zero out the key and value entries corresponding to the padding positions + key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2)) + value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3)) + # The second graph with the input size of (bs, 1) + # the shape of query is (bs, num_heads, 1, size_per_head) + # the shape of key is (bs, num_heads, size_per_head, 1) + # the shape of value is (bs, num_heads, 1, size_per_head) + else: + # Get the current token position index + valid_length = self.reducesum(F.cast(self.not_equal(self.slice(key_past, (0, 0, 0, 0), + (F.shape(x)[0], 1, 1, self.seq_length), + (1, 1, 1, 1)), + 0), mstype.float32), (1, 2, 3)) + valid_length = F.reshape(valid_length, (-1, 1, 1)) + valid_length_vector = F.cast(self.equal(valid_length, self.range), self.dtype) + # Expand the key and value to seq_length, leaving them non-zero only at the current position index + current_key = self.mul1(self.tile(key, (1, 1, 1, self.seq_length)), + self.expand_dims(valid_length_vector, 2)) + current_value = self.mul1(self.tile(value, (1, 1, self.seq_length, 1)), + self.expand_dims(valid_length_vector, 3)) + # Merge the previously saved state and the current state + key = self.add(key_past, current_key) + value = self.add(value_past, current_value) + # Update key_present and value_present for the state update + key_present = key + value_present = value + attention_mask = F.reshape(self.attention_mask, (self.seq_length, self.seq_length, 1, 1)) + + layer_present = (key_present, value_present) # Self-attention considering attention mask attention = self._attn(query, key, value, attention_mask) # [bs, seq_length, embedding_size] @@ -523,7 +578,7 @@ class Attention(nn.Cell): query = query / F.cast(self.coeff, F.dtype(query)) key = key / F.cast(self.coeff, F.dtype(key)) - # Attention score [bs, num_heads, seq_length, seq_length] + # Attention score [bs, num_heads, seq_length_q, seq_length_k] score = self.batch_matmul(query, key) # Normalize after query and key MatMul, default on if self.scale: @@ -533,6 +588,22 @@ ori_dtype = P.DType()(score) score = P.Cast()(score, mstype.float32) + + # For the input size of (bs, 1), i.e., the second graph, the shape of the attention_mask matrix should be + # (bs, 1, 1, seq_length) + if self.use_past and not self.is_first_iteration: + # Calculate the current number of valid tokens + current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0), + (F.shape(query)[0], 1, 1, self.seq_length), + (1, 1, 1, 1)), + 0), mstype.float32), (1, 2, 3)) + # Get the precise position index + index = self.sub1(F.cast(current_index, mstype.int32), 1) + index = F.reshape(index, (-1, 1, 1)) + # Calculate the attention_mask matrix via the position index + attention_mask = F.cast(self.tensor_le(self.range, index), mstype.int32) + attention_mask = self.expand_dims(attention_mask, 2) + # Minus 10000 for the position where masked to exclude them from softmax multiplu_out = self.sub( P.Cast()(F.tuple_to_array((1.0,)), P.DType()(score)), @@ -550,27 +621,28 @@ attention_probs = F.reshape(attention_probs, shape) attention_probs = self.prob_dropout(attention_probs) - # Weighted sum output [bs, num_heads, seq_length, size_per_head] + # Weighted sum output
[bs, num_heads, seq_length_q, size_per_head] weighted_values = self.batch_matmul(attention_probs, value) return weighted_values -class Block(nn.Cell): +class Decoder(nn.Cell): """ - The basic block of PanguAlpha network + The basic decoder structure of PanguAlpha network Args: config(PanguAlphaConfig): the config of network layer_idx: current layer index Inputs: x: the output of previous layer (input_ids for the first layer) attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length) - layer_past: the previous feature map + init_reset: whether to reset the previously saved state + batch_valid_length: the valid input seq_length without padding Returns: output: Tensor, the output logit of this layer layer_present: Tensor, the feature map of current layer """ def __init__(self, config, layer_idx): - super(Block, self).__init__() + super(Decoder, self).__init__() scale = 1 / math.sqrt(2.0 * config.num_layers) if config.self_layernorm: @@ -593,16 +665,43 @@ self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1))) # Last activation of this layer will be saved for recompute in backward process self.dtype = config.compute_dtype + self.use_past = config.use_past + if self.use_past: + # operators used for state reuse + self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),)) + self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ())) + self.slice = P.StridedSlice().shard(((1, 1, 1, 1),)) + size_per_head = int(config.embedding_size / config.num_heads) + self.key_shape = (config.batch_size, config.num_heads, size_per_head, config.seq_length) + self.value_shape = (config.batch_size, config.num_heads, config.seq_length, size_per_head) + # parameters saving key and value states + self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past") + self.value_past = Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past") + self.tile = P.Tile().shard(((1, 1),)) + self.mul = P.Mul().shard(((1, 1, 1, 1), (1,))) + self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1))) - def construct(self, x, input_mask, layer_past=None): + def construct(self, x, input_mask, init_reset=True, batch_valid_length=None): r""" The forward process of the block.
""" # [bs, seq_length, embedding_size] input_x = self.layernorm1(x) input_x = F.cast(input_x, self.dtype) + + # indicate whether reset saved states + key_reset = None + value_reset = None + + if self.use_past: + # reset states, init_reset True for reuse and False for reset + key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype))) + value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype))) + # add dependency for desired execution order + input_x = F.depend(input_x, key_reset) + input_x = F.depend(input_x, value_reset) attention, layer_present = self.attention(input_x, input_mask, - layer_past) + self.key_past, self.value_past, batch_valid_length) # For post-layernorm the inputs for residual path are output of self-attention and output of layernorm if self.post_layernorm_residual: x = self.add(input_x, attention) @@ -613,6 +712,22 @@ class Block(nn.Cell): output_x = self.layernorm2(x) output_x = F.cast(output_x, self.dtype) mlp_logit = self.output(output_x) + + value_update = None + key_update = None + if self.use_past: + # current key and value + key_present, value_present = layer_present + # update key and value calculated this step + key_update = self.assign(self.key_past, key_present) + value_update = self.assign(self.value_past, value_present) + # add dependency for desired execution order + key_update = F.depend(key_update, key_reset) + value_update = F.depend(value_update, value_reset) + + # add dependency for desired execution order + mlp_logit = F.depend(mlp_logit, value_update) + mlp_logit = F.depend(mlp_logit, key_update) if self.post_layernorm_residual: output = self.add(output_x, mlp_logit) else: @@ -635,9 +750,14 @@ class PanguAlpha_EmbeddingPipeLine(nn.Cell): self.dropout = Dropout(1 - config.dropout_rate) self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),)) self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),)) + self.use_past = config.use_past + self.is_first_iteration = True - def construct(self, input_ids, table, input_position): + def construct(self, input_ids, table, input_position, valid_index=None): input_embedding = self.word_embedding(input_ids, table) + if self.use_past and not self.is_first_iteration: + _, seq_length = F.shape(input_ids) + input_position = valid_index.view(1, seq_length) position_embedding = self.position_embedding(input_position) hidden_states = self.add(input_embedding, position_embedding) hidden_states = self.dropout(hidden_states) @@ -662,7 +782,7 @@ class QueryLayerAttention(Attention): r""" Self-Attention module using input query vector. 
""" - def construct(self, x, query_hidden_state, attention_mask, layer_past=None): + def construct(self, x, query_hidden_state, attention_mask, key_past=None, value_past=None, batch_valid_length=None): original_shape = F.shape(x) x = F.reshape(x, (-1, original_shape[-1])) query_hidden_state = F.reshape(query_hidden_state, (-1, original_shape[-1])) @@ -684,12 +804,42 @@ class QueryLayerAttention(Attention): value, (-1, original_shape[1], self.n_head, self.size_per_head)), (0, 2, 1, 3)) + + key_present = key + value_present = value if self.use_past: - past_value = layer_past[1] - past_key = self.transpose(layer_past[0], (0, 1, 3, 2)) - key = self.concat_k((past_key, key)) - value = self.concat_v(past_value, value) - layer_present = P.Pack()([self.transpose(key, (0, 1, 3, 2)), value]) + # The first graph with the input size of (bs, seq_length) + if self.is_first_iteration: + # Get the valid input length without padding + valid_length_vector = F.cast(self.less(self.range, batch_valid_length), self.dtype) + # Cover the key and value numbers corresponding to the padding position + key_present = self.mul1(key, self.expand_dims(valid_length_vector, 2)) + value_present = self.mul1(value, self.expand_dims(valid_length_vector, 3)) + # The second graph with the inpus size of (bs, 1) + # the shape of query is (bs, num_heads, 1, size_per_head) + # the shape of key is (bs, num_heads, size_per_head, 1) + # the shape of value is (bs, num_heads, 1, size_per_head) + else: + # Get the current token position index + valid_length = self.reducesum(F.cast(self.not_equal(self.slice(key_past, (0, 0, 0, 0), + (F.shape(x)[0], 1, 1, self.seq_length), + (1, 1, 1, 1)), + 0), mstype.float32), (1, 2, 3)) + valid_length = F.reshape(valid_length, (-1, 1, 1)) + valid_length_vector = F.cast(self.equal(valid_length, self.range), self.dtype) + # Pad the key and value to seq_length with only the position index not zero + current_key = self.mul1(self.tile(key, (1, 1, 1, self.seq_length)), + self.expand_dims(valid_length_vector, 2)) + current_value = self.mul1(self.tile(value, (1, 1, self.seq_length, 1)), + self.expand_dims(valid_length_vector, 3)) + # Concat the previous saved state and current state + key = self.add(key_past, current_key) + value = self.add(value_past, current_value) + # Update key_present and value_present for state update + key_present = key + value_present = value + attention_mask = F.reshape(self.attention_mask, (self.seq_length, self.seq_length, 1, 1)) + layer_present = (key_present, value_present) attention = self._attn(query, key, value, attention_mask) attention_merge = self.merge_heads(attention) output = self.projection(attention_merge) @@ -715,18 +865,48 @@ class QueryLayer(nn.Cell): self.post_layernorm_residual = config.post_layernorm_residual self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1))) self.dtype = config.compute_dtype + self.use_past = config.use_past + if self.use_past: + # operator used for state reuse + self.reducesum = P.ReduceSum().shard(((1, 1, 1, 1),)) + self.not_equal = P.NotEqual().shard(((1, 1, 1, 1), ())) + self.slice = P.StridedSlice().shard(((1, 1, 1, 1),)) + size_per_head = int(config.embedding_size / config.num_heads) + self.key_shape = (config.batch_size, config.num_heads, size_per_head, config.seq_length) + self.value_shape = (config.batch_size, config.num_heads, config.seq_length, size_per_head) + # parameters saving key and value states + self.key_past = Parameter(Tensor(np.zeros(shape=self.key_shape), self.dtype), name="key_past") + self.value_past = 
Parameter(Tensor(np.zeros(shape=self.value_shape), self.dtype), name="value_past") + self.tile = P.Tile().shard(((1, 1),)) + self.mul = P.Mul().shard(((1, 1, 1, 1), (1,))) + self.assign = P.Assign().shard(((1, 1, 1, 1), (1, 1, 1, 1))) - def construct(self, x, query_hidden_state, input_mask, layer_past=None): + def construct(self, x, query_hidden_state, input_mask, init_reset=True, batch_valid_length=None): r""" Query Layer shares a similar structure with the normal layer block except that it is not a traditional self-attention. """ input_x = self.layernorm1(x) input_x = F.cast(input_x, self.dtype) + + # Indicate whether to reset the saved states + key_reset = None + value_reset = None + + if self.use_past: + # Reset the states: init_reset is True to reuse the saved states and False to reset them + key_reset = self.assign(self.key_past, self.mul(self.key_past, F.cast(init_reset, self.dtype))) + value_reset = self.assign(self.value_past, self.mul(self.value_past, F.cast(init_reset, self.dtype))) + # Add dependencies to enforce the desired execution order + input_x = F.depend(input_x, key_reset) + input_x = F.depend(input_x, value_reset) + attention, layer_present = self.attention(input_x, query_hidden_state, input_mask, - layer_past) + self.key_past, + self.value_past, + batch_valid_length) if self.post_layernorm_residual: x = self.add(input_x, attention) else: @@ -735,28 +915,46 @@ output_x = self.layernorm2(x) output_x = F.cast(output_x, self.dtype) mlp_logit = self.output(output_x) + value_update = None + key_update = None + if self.use_past: + # Current key and value + key_present, value_present = layer_present + # Update the key and value calculated at this step + key_update = self.assign(self.key_past, key_present) + value_update = self.assign(self.value_past, value_present) + # Add dependencies to enforce the desired execution order + key_update = F.depend(key_update, key_reset) + value_update = F.depend(value_update, value_reset) + + # Add dependencies to enforce the desired execution order + mlp_logit = F.depend(mlp_logit, value_update) + mlp_logit = F.depend(mlp_logit, key_update) + if self.post_layernorm_residual: output = self.add(output_x, mlp_logit) else: output = self.add(x, mlp_logit) return output, layer_present -class PanguAlpha_Model(nn.Cell): +class Embedding(nn.Cell): """ - The backbone of PanguAlpha network + Input embedding, i.e., word embedding and position embedding Args: config(PanguAlphaConfig): the config of network Inputs: input_ids: the tokenized inputs with datatype int32 input_mask: the mask indicating whether each position is a valid input - layer_past: the previous feature map - Returns: - output_state: Tensor, the output logit of backbone - present_layer: Tensor, the current feature map - embedding_table: Tensor, the embedding table for the vocabulary + input_position: the position index of each token + attention_mask: the attention mask matrix for the self-attention module + valid_index: the position index of the current token, used only in incremental inference + Returns: + hidden_states: Tensor, input embeddings + attention_mask: Tensor, attention_mask matrix + embedding_table: Tensor, embedding_table with shape of (vocab_size, embedding_size) """ def __init__(self, config): - super(PanguAlpha_Model, self).__init__() + super(Embedding, self).__init__() self.get_attention_mask = AttentionMask(config) # Word embedding self.word_embedding = EmbeddingLookup(config).set_comm_fusion(1) @@ -780,6 +978,59 @@ self.position_embedding.embedding_table.parallel_optimizer = False
self.position_embedding.gather.shard(((1, 1), (config.dp,))) self.position_embedding.expand.shard(((config.dp, 1),)) + self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1))) + self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),)) + self.dropout = Dropout(1 - config.dropout_rate) + self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),)) + self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),)) + self.eod_reset = config.eod_reset + self.use_past = config.use_past + self.is_first_iteration = True + + def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, valid_index=None): + """ + Calculate input embeddings via input token ids and input position + """ + # Word embedding + input_embedding, embedding_table = self.word_embedding(input_ids) + # If eod_reset is disabled, there will be only one input from the dataset, i.e., input_ids, + # and the corresponding input_position and attention_mask will be derived from it. + if not self.eod_reset: + batch_size, seq_length = F.shape(input_ids) + attention_mask = self.get_attention_mask(input_mask) + if self.use_past and not self.is_first_iteration: + input_position = valid_index.view(1, seq_length) + else: + input_position = F.tuple_to_array(F.make_range(seq_length)) + input_position = P.Tile()(input_position, (batch_size, 1)) + position_embedding = self.position_embedding(input_position) + # Input features [bs, seq_length, embedding_size] + hidden_states = self.add(input_embedding, position_embedding) + hidden_states = self.dropout(hidden_states) + hidden_states = P.Cast()(hidden_states, mstype.float16) + attention_mask = self.expand_dims(attention_mask, 1) + return hidden_states, attention_mask, embedding_table + + +class PanguAlpha_Model(nn.Cell): + """ + The backbone of PanguAlpha network + Args: + config(PanguAlphaConfig): the config of network + Inputs: + input_ids: the tokenized inputs with datatype int32 + input_mask: the mask indicating whether each position is a valid input + input_position: the position index of each token + attention_mask: the attention mask matrix for the self-attention module + init_reset: whether to reset the saved key and value states + batch_valid_length: the valid input sequence length without padding + Returns: + output_state: Tensor, the output logit of backbone + embedding_table: Tensor, the embedding table for the vocabulary + """ + def __init__(self, config): + super(PanguAlpha_Model, self).__init__() + self.embedding = Embedding(config) self.blocks = nn.CellList() # Total fusion groups for HCCL operators. Specifically, the same type HCCL operators in the same group will be fused. fusion_group_num = 4 @@ -794,9 +1045,9 @@ print("After setting the layer is:", num_layers, flush=True) for i in range(num_layers): - per_block = Block(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2) - # Each layer will be recomputed in the backward process. The output activation of each layer will be saved, - # in other words, in backward process each block will be almost totally recomputed. + per_block = Decoder(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2) + # Each layer will be recomputed in the backward process. The output activation of each layer will be saved; + # in other words, in the backward process each block will be almost totally recomputed.
if config.use_recompute: per_block.recompute() self.blocks.append(per_block) @@ -813,13 +1064,7 @@ self.layernorm.beta.parallel_optimizer = False self.use_past = config.use_past self.past = tuple([None] * config.num_layers) - self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1))) - self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),)) self.dtype = config.compute_dtype - self.dropout = Dropout(1 - config.dropout_rate) - self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),)) - self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),)) - self.eod_reset = config.eod_reset # If top_query_attention enabled, the input_position representing the position ids will be used as the index # for a query embedding table to obtain top query hidden states, together with the previous outputs of normal # self-attention layers, a new attention layer will be attached to the output of the model @@ -853,33 +1098,18 @@ self.use_top_query_attention = config.use_top_query_attention - def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, layer_past=None): + def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, + init_reset=True, batch_valid_length=None): """PanguAlpha model""" - if not self.use_past: - layer_past = self.past + # Embedding for input_ids and the lower-triangular attention_mask matrix + hidden_states, attention_mask, embedding_table = self.embedding(input_ids, input_mask, + input_position, attention_mask, + batch_valid_length) - # Word embedding - input_embedding, embedding_table = self.word_embedding(input_ids) - # If eod_reset disabled, there will be only one input from the dataset, i.e., input_ids - and the corresponding input_position and attention_mask will be derived from it.
- if not self.eod_reset: - batch_size, seq_length = F.shape(input_ids) - input_position = F.tuple_to_array(F.make_range(seq_length)) - input_position = P.Tile()(input_position, (batch_size, 1)) - attention_mask = self.get_attention_mask(input_mask) - position_embedding = self.position_embedding(input_position) - # Input features [bs, seq_length, embedding_size] - hidden_states = self.add(input_embedding, position_embedding) - hidden_states = self.dropout(hidden_states) - hidden_states = P.Cast()(hidden_states, mstype.float16) - attention_mask = self.expand_dims(attention_mask, 1) - - present_layer = () # Loop through each self-attention layer for i in range(self.num_layers): - hidden_states, present = self.blocks[i](hidden_states, - attention_mask, layer_past) - present_layer = present_layer + (present,) + hidden_states, _ = self.blocks[i](hidden_states, + attention_mask, init_reset, batch_valid_length) output_state = self.layernorm(hidden_states) output_state = F.cast(output_state, self.dtype) @@ -887,10 +1117,9 @@ class PanguAlpha_Model(nn.Cell): # Top query attention layer if self.use_top_query_attention: top_query_hidden_states = self.top_query_embedding(input_position) - output_state, present = self.top_query_layer(output_state, top_query_hidden_states, - attention_mask, layer_past) - present_layer = present_layer + (present,) - return output_state, present_layer, embedding_table + output_state, _ = self.top_query_layer(output_state, top_query_hidden_states, + attention_mask, init_reset, batch_valid_length) + return output_state, embedding_table class PanguAlpha_ModelPipeline(nn.Cell): """ @@ -922,7 +1151,7 @@ class PanguAlpha_ModelPipeline(nn.Cell): self.top_query_embedding.pipeline_stage = i * config.stage_num // config.num_layers per_block = QueryLayer(config).set_comm_fusion(2) else: - per_block = Block(config, i + 1).set_comm_fusion(2) + per_block = Decoder(config, i + 1).set_comm_fusion(2) per_block.pipeline_stage = i * config.stage_num // config.num_layers per_block.recompute() self.blocks.append(per_block) @@ -936,27 +1165,23 @@ class PanguAlpha_ModelPipeline(nn.Cell): self.layernorm.set_comm_fusion(2) self.layernorm.pipeline_stage = config.stage_num - 1 self.use_past = config.use_past - self.past = tuple([None] * config.num_layers) self.dtype = config.compute_dtype self.num_layers = config.num_layers - def construct(self, input_ids, input_mask, table, input_position, attention_mask, layer_past=None): + def construct(self, input_ids, input_mask, table, input_position, attention_mask, + init_reset=True, batch_valid_length=None): """PanguAlpha model""" - if not self.use_past: - layer_past = self.past hidden_states = self.pangu_alpha_embedding(input_ids, table, input_position) attention_mask = self.pangu_alpha_mask(input_mask, attention_mask) - present_layer = () for i in range(self.num_layers-1): - hidden_states, present = self.blocks[i](hidden_states, - attention_mask, layer_past) - present_layer = present_layer + (present,) + hidden_states, _ = self.blocks[i](hidden_states, + attention_mask, init_reset, batch_valid_length) + top_query_hidden_states = self.top_query_embedding(input_position) - hidden_states, present = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states, - attention_mask, layer_past) - present_layer = present_layer + (present,) + hidden_states, present_layer = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states, + attention_mask, init_reset, batch_valid_length) output_state = self.layernorm(hidden_states) output_state = 
F.cast(output_state, self.dtype) return output_state, present_layer @@ -998,7 +1223,10 @@ Inputs: input_ids: the tokenized inputs input_mask: the mask indicating whether each position is a valid input - past: the previous feature map + input_position: the position index of each token + attention_mask: the attention mask matrix for the self-attention module + init_reset: whether to reset the saved key and value states + batch_valid_length: the valid input sequence length without padding Returns: logits: Tensor, the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size) """ @@ -1009,9 +1237,10 @@ # Network head to get logits over vocabulary self.head = PanguAlpha_Head(config) - def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, past=None): - output_states, _, embedding_table = self.backbone( - input_ids, input_mask, input_position, attention_mask, past) + def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, + init_reset=True, batch_valid_length=None): + output_states, embedding_table = self.backbone( + input_ids, input_mask, input_position, attention_mask, init_reset, batch_valid_length) logits = self.head(output_states, embedding_table) return logits @@ -1039,9 +1268,10 @@ self.embedding_table.add_pipeline_stage(self.backbone.blocks[0].pipeline_stage) self.embedding_table.add_pipeline_stage(self.head.pipeline_stage) - def construct(self, input_ids, input_mask, input_position, attention_mask, past=None): + def construct(self, input_ids, input_mask, input_position, attention_mask, + init_reset=True, batch_valid_length=None): output_states, _ = self.backbone(input_ids, input_mask, self.embedding_table, - input_position, attention_mask, past) + input_position, attention_mask, init_reset, batch_valid_length) logits = self.head(output_states, self.embedding_table) return logits @@ -1213,24 +1443,27 @@ class EvalNet(nn.Cell): Inputs: input_ids: the tokenized inputs current_index: the index of current token + init_reset: whether to reset the saved states Returns: outputs: Tensor, corresponding output for different tasks - def __init__(self, backbone, generate=False): + def __init__(self, backbone, generate=False, pad_token=6): super(EvalNet, self).__init__(auto_prefix=False) self.backbone = backbone + self.pad_token = pad_token self.argmax = P.Argmax() self.generate = generate self.topk = P.TopK(sorted=True).shard(((1, 1),)) - self.gather = P.GatherV2().shard(((1, 1, 1), (1,))) - self.log_softmax = P.LogSoftmax().shard(((1, 1),)) + self.gather = P.GatherV2().shard(((1, 1), (1,))) + self.log_softmax = P.LogSoftmax().shard(((1, 1, 1),)) - def construct(self, input_ids, current_index): + def construct(self, input_ids, current_index, init_reset=True, batch_valid_length=None): """evaluation net""" - input_mask = F.cast(F.not_equal(input_ids, 0), mstype.float32) + input_mask = F.cast(F.not_equal(input_ids, self.pad_token), mstype.float32) - logits = self.backbone(input_ids, input_mask) + logits = self.backbone(input_ids, input_mask, None, None, init_reset, batch_valid_length) - bs, seq_length = F.shape(input_ids) - logits = F.reshape(logits, (bs, seq_length, -1)) - logits = self.gather(logits, current_index, 1) + index = current_index.view(1,) + logits = self.gather(logits, index, 0) + bs, _ = F.shape(input_ids) + logits = logits.view(bs, 1, -1) log_probs = self.log_softmax(logits) return log_probs
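Note: the fixed-shape key/value state update implemented in Attention.construct above is easier to follow outside the sharded MindSpore operators. Below is a minimal NumPy sketch of the same idea (illustrative only, not part of the patch; the helper name update_kv_cache is hypothetical): the single-token key/value is tiled to seq_length, masked by a one-hot vector at the current position, and added onto the saved state, so the cache keeps a static shape suitable for graph compilation.

import numpy as np

def update_kv_cache(key_past, value_past, key, value, position, seq_length):
    # key: (bs, heads, size_per_head, 1), value: (bs, heads, 1, size_per_head)
    one_hot = np.zeros(seq_length, dtype=key_past.dtype)
    one_hot[position] = 1.0
    # Tile the single-token key/value to seq_length and zero every slot except `position`,
    # mirroring the tile/mul1/add sequence in the patch
    current_key = np.tile(key, (1, 1, 1, seq_length)) * one_hot.reshape(1, 1, 1, -1)
    current_value = np.tile(value, (1, 1, seq_length, 1)) * one_hot.reshape(1, 1, -1, 1)
    # Previously filled positions are preserved; the new position is written exactly once
    return key_past + current_key, value_past + current_value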
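Likewise, the penalty-and-sampling step shared by generate() and generate_increment() can be summarized in plain NumPy. This is a sketch of the general technique (frequency/presence penalties followed by nucleus or top-k sampling), not the patch's exact code path: it normalizes the cumulative mass before applying the top_p cutoff, which differs slightly from the patch's sum(cumsum_logits > top_p) heuristic over unnormalized values.

import numpy as np

def sample_next_token(log_probs, frequency_list, frequency_penalty, presence_penalty,
                      top_p=0.9, top_k_num=3000):
    # log_probs, frequency_list: (vocab_size,); log_probs are base-10 log-probabilities
    revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty
    probs = np.power(10.0, revised)
    if top_p < 1.0:
        # Nucleus sampling: keep the smallest set of top tokens whose mass reaches top_p
        order = np.argsort(-probs)
        sorted_probs = probs[order]
        cumulative = np.cumsum(sorted_probs) / sorted_probs.sum()
        cutoff = int(np.searchsorted(cumulative, top_p)) + 1
        p_args, kept = order[:cutoff], sorted_probs[:cutoff]
    else:
        # Top-k sampling
        p_args = np.argsort(-probs)[:top_k_num]
        kept = probs[p_args]
    p = kept / kept.sum()
    return int(p_args[np.random.choice(len(p), p=p)])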