pangu alpha modify
parent fb6ec96862
commit 327ba1962a
@@ -1330,10 +1330,10 @@ class Cell(Cell_):
        Raises:
            RuntimeError: If there is a parameter that does not belong to any stage.
        """
        from mindspore.communication import get_group_size, get_rank
        from mindspore.parallel._utils import _get_global_rank, _get_device_num
        stage_num = context.get_auto_parallel_context("pipeline_stages")
        device_num = get_group_size()
        rank_id = get_rank()
        device_num = _get_device_num()
        rank_id = _get_global_rank()
        per_stage_devices = device_num // stage_num
        current_stage = rank_id // per_stage_devices
        params = []
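Note: a small worked example of the stage arithmetic above (device counts are illustrative only). Each pipeline stage owns an equal, contiguous block of devices, and a rank's stage index falls out of integer division:

# e.g. device_num = 32, stage_num = 4  ->  per_stage_devices = 8
# rank_id 0..7 belong to stage 0, rank_id 8..15 to stage 1, and so on.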
@@ -20,7 +20,7 @@ from mindspore._checkparam import Validator
from mindspore.common.dtype import pytype_to_dtype
from .. import context, nn
from ._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes
from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes, _get_pipeline_stages
from ..ops import operations as P
@@ -158,7 +158,7 @@ def connect_network_with_dataset(network, dataset_helper):
        network = dataset_iter.__network_manage__[key]
    else:
        if _need_to_full():
            device_num = _get_device_num()
            device_num = _get_device_num() // _get_pipeline_stages()
            dataset_shapes = _to_full_shapes(dataset_shapes, device_num)

        network = _generate_dataset_sink_mode_net(network, dataset_shapes, dataset_types, queue_name)
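Note: with pipeline parallelism enabled, only the devices inside one stage contribute to the data-parallel dimension, so the full-batch dataset shapes are now expanded by device_num // stage_num rather than by the whole cluster size; the same adjustment appears in _DatasetIterGE and _DatasetIterMSLoopSink below. Illustrative numbers:

# e.g. 64 devices and 8 pipeline stages -> shapes are expanded 64 // 8 = 8 times,
# matching the data-parallel width inside a single stage.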
@@ -358,7 +358,7 @@ class _DatasetIterGE(_DatasetIter):
        self.sink_count = self.get_sink_count(dataset)
        batch_expand_num = 1
        if _need_to_full():
            batch_expand_num = _get_device_num()
            batch_expand_num = _get_device_num() // _get_pipeline_stages()
        tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)

        def op():
@@ -393,7 +393,7 @@ class _DatasetIterMSLoopSink(_DatasetIter):
        # Use a complete tensor to compile and slice the tensor to run. The batch dimension of the tensors used
        # for compiling is device_number times the batch dimension of the tensors used for running. Currently only LoopSink is supported.
        if _need_to_full():
            device_num = _get_device_num()
            device_num = _get_device_num() // _get_pipeline_stages()
            self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)

        def op():
@@ -98,7 +98,6 @@ def load_model(args_opt):
        dropout_rate=0.0,
        compute_dtype=mstype.float16,
        use_past=use_past,
        self_layernorm=True,
        stage_num=args_opt.stage_num,
        micro_size=args_opt.micro_size,
        eod_reset=False,
@@ -205,7 +205,6 @@ def generate_increment(model, origin_inputs, config):
        log_probs = logits.reshape(1, config.vocab_size)

        # Get the revised log_probs considering frequency and presence penalty to eliminate duplicates in the generated results
        log_probs = log_probs.asnumpy().reshape(1, config.vocab_size)
        log_probs_revised = log_probs - frequency_list * frequency_penalty - (frequency_list > 0) * presence_penalty

        p, p_args = sampler(log_probs_revised, top_p, top_k_num, use_pynative)
@@ -220,51 +220,9 @@ class Output(nn.Cell):
        output = self.dropout(output)
        return output


class AttentionMask(nn.Cell):
    r"""
    Get the attention matrix for the self-attention module
    Args:
        config(PanguAlphaConfig): the config of the network
    Inputs:
        input_mask: the mask indicating whether each position is a valid input
    Returns:
        attention_mask: the attention mask matrix with shape (batch_size, 1, seq_length, seq_length)
    """
    def __init__(self, config):
        super(AttentionMask, self).__init__()
        self.reshape = P.Reshape()
        self.mul = P.BatchMatMul().shard(
            ((config.dp, 1, 1), (config.dp, 1, 1)))  # yzz: use 64, 1, 1?
        self.expand_dim = P.ExpandDims().shard(((1, 1),))
        ones = np.ones(shape=(config.seq_length, config.seq_length))
        # Default lower triangle mask matrix
        self.lower_triangle_mask = Tensor(np.tril(ones), mstype.float32)
        self.multiply = P.Mul().shard(((config.dp, 1, 1), (1, 1, 1)))

    def construct(self, input_mask):
        r"""
        Generate the attention mask matrix.
        """
        input_shape = P.Shape()(input_mask)
        shape_right = (input_shape[0], 1, input_shape[1])
        shape_left = input_shape + (1,)
        # Mask the padded inputs
        mask_left = self.reshape(input_mask, shape_left)
        mask_right = self.reshape(input_mask, shape_right)
        attention_mask = self.mul(mask_left, mask_right)
        lower_triangle = self.expand_dim(self.lower_triangle_mask, 0)
        # [bs, seq_length, seq_length]
        attention_mask = self.multiply(
            attention_mask, lower_triangle)
        return attention_mask
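Note: the mask above is the outer product of the padding mask with itself, multiplied elementwise by a lower-triangular matrix so no position can attend to the future. A tiny illustration with made-up values:

# input_mask = [1, 1, 1, 0]          # one sequence whose last position is padding
# mask_left x mask_right -> a 4x4 matrix that is 1 only where both positions are valid
# multiplying by tril(ones) then removes attention to future positions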


class EmbeddingLookup(nn.Cell):
    """
    The embedding lookup table for the vocabulary
    Args:
        config(PanguAlphaConfig): the config of the network
    Inputs:
        input_ids: the tokenized inputs with datatype int32
    Returns:
@@ -272,40 +230,11 @@ class EmbeddingLookup(nn.Cell):
                     seq_length, embedding_size)
        self.embedding_table: Tensor, the embedding table for the vocabulary
    """
    def __init__(self, config):
    def __init__(self):
        super(EmbeddingLookup, self).__init__()
        self.vocab_size = config.vocab_size
        self.embedding_size = config.embedding_size
        self.gather = P.GatherV2()

        if config.word_emb_dp:
            self.gather = P.GatherV2().shard(((1, 1), (config.dp, 1)))
        else:
            self.gather = P.GatherV2().shard(((config.mp, 1), (1, 1)))
        self.shape = (-1, config.seq_length, config.embedding_size)
        if config.stage_num > 1:
            self.construct = self.construct_pipeline
            self.gather.add_prim_attr("parameter_start", 0)
        else:
            if config.load_ckpt_path:
                # Loading the embedding table from the ckpt path:
                embedding_path = os.path.join(config.load_ckpt_path, 'word_embedding.npy')
                if os.path.exists(embedding_path):
                    e_table = np.load(embedding_path)
                    e_table = Tensor(e_table, mstype.float32)
                    self.embedding_table = Parameter(e_table, name="embedding_table")
                else:
                    raise ValueError(f"{embedding_path} file does not exist, "
                                     f"please check whether the word_embedding file exists.")
            else:
                self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]),
                                                 name="embedding_table")
            self.construct = self.construct_no_pipeline

    def construct_no_pipeline(self, input_ids):
        output = self.gather(self.embedding_table, input_ids, 0)
        return output, self.embedding_table

    def construct_pipeline(self, input_ids, table):
    def construct(self, input_ids, table):
        output = self.gather(table, input_ids, 0)
        return output
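Note: after this change EmbeddingLookup no longer owns a table or a config; the caller configures sharding and passes the (possibly pipeline-shared) embedding table in. A minimal usage sketch, assuming the caller has already created an embedding_table Parameter:

# lookup = EmbeddingLookup()
# lookup.gather.shard(((1, 1), (config.dp, 1)))        # sharding is set by the caller now
# word_embedding = lookup(input_ids, embedding_table)   # the table is an input, not an attribute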
@@ -320,8 +249,6 @@ class Attention(nn.Cell):
    """
    def __init__(self, config, scale=1.0, layer_idx=None):
        super(Attention, self).__init__()
        # Attention mask matrix
        self.get_attention_mask = AttentionMask(config)
        # Output layer
        self.projection = Mapping(config, config.embedding_size,
                                  config.embedding_size, scale)
@@ -625,19 +552,8 @@ class Decoder(nn.Cell):
    def __init__(self, config, layer_idx):
        super(Decoder, self).__init__()
        scale = 1 / math.sqrt(2.0 * config.num_layers)

        if config.self_layernorm:
            self.layernorm1 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
            self.layernorm2 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
        else:
            self.layernorm1 = nn.LayerNorm((config.embedding_size,)).to_float(mstype.float32)
            self.layernorm1.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
            self.layernorm2 = nn.LayerNorm((config.embedding_size,)).to_float(mstype.float32)
            self.layernorm2.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
        self.layernorm1.gamma.parallel_optimizer = False
        self.layernorm1.beta.parallel_optimizer = False
        self.layernorm2.gamma.parallel_optimizer = False
        self.layernorm2.beta.parallel_optimizer = False
        self.layernorm1 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
        self.layernorm2 = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)

        self.attention = Attention(config, scale, layer_idx)
        # Feed-Forward Network (FFN)
@@ -715,16 +631,40 @@ class Decoder(nn.Cell):
        output = self.add(x, mlp_logit)
        return output, layer_present

class PanguAlpha_EmbeddingPipeLine(nn.Cell):
class EmbeddingCell(nn.Cell):
    """
    PanguAlpha_EmbeddingPipeLine
    EmbeddingCell
    """
    def __init__(self, config):
        super(PanguAlpha_EmbeddingPipeLine, self).__init__()
        self.word_embedding = EmbeddingLookup(config)
        self.position_embedding = nn.Embedding(config.seq_length,
                                               config.embedding_size,
                                               embedding_table=Normal(0.02))
        super(EmbeddingCell, self).__init__()
        self.word_embedding = EmbeddingLookup().set_comm_fusion(1)
        if config.word_emb_dp:
            self.word_embedding.gather.shard(((1, 1), (config.dp, 1)))
        else:
            self.word_embedding.gather.shard(((config.mp, 1), (1, 1)))
        if config.stage_num > 1:
            self.position_embedding = nn.Embedding(config.seq_length,
                                                   config.embedding_size,
                                                   embedding_table=Normal(0.02)).set_comm_fusion(1)
        else:
            # Position embedding
            if config.load_ckpt_path:
                # Loading the embedding table from the ckpt path:
                embedding_path = os.path.join(config.load_ckpt_path, 'position_embedding.npy')
                if os.path.exists(embedding_path):
                    p_table = np.load(embedding_path)
                    position_table_param = Tensor(p_table, mstype.float32)
                else:
                    raise ValueError(f"{embedding_path} file does not exist, "
                                     f"please check whether the position_embedding file exists.")
            else:
                position_table_param = TruncatedNormal(0.02)
            # Position embedding
            self.position_embedding = nn.Embedding(
                config.seq_length,
                config.embedding_size,
                embedding_table=position_table_param).set_comm_fusion(1)
        self.position_embedding.embedding_table.parallel_optimizer = False
        self.position_embedding.gather.shard(((1, 1), (config.dp,)))
        self.position_embedding.expand.shard(((config.dp, 1),))
        self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
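Note: EmbeddingCell bundles the shared word-embedding lookup with the position embedding, while the attention-mask handling moves to MaskCell below. Its construct is not part of this hunk; a rough sketch of how the pieces above are presumably combined:

# word = self.word_embedding(input_ids, table)    # table is provided by the outer PanguAlpha cell
# pos = self.position_embedding(input_position)
# hidden_states = self.add(word, pos)              # [bs, seq_length, embedding_size]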
@@ -746,16 +686,15 @@ class PanguAlpha_EmbeddingPipeLine(nn.Cell):
        return hidden_states


class PanguAlpha_Mask(nn.Cell):
class MaskCell(nn.Cell):
    """
    PanguAlpha_Mask
    MaskCell
    """
    def __init__(self, config):
        super(PanguAlpha_Mask, self).__init__()
        self.get_attention_mask = AttentionMask(config)
        super(MaskCell, self).__init__()
        self.dtype = config.compute_dtype
        self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
    def construct(self, input_mask, attention_mask):
    def construct(self, attention_mask):
        attention_mask = self.expand_dims(attention_mask, 1)
        return attention_mask
@@ -914,7 +853,7 @@ class QueryLayer(nn.Cell):
        output = self.add(x, mlp_logit)
        return output, layer_present

class Embedding(nn.Cell):
class PanguAlphaEmbedding(nn.Cell):
    """
    Input embedding, i.e., word embedding and position embedding
    Args:
@@ -931,174 +870,21 @@ class Embedding(nn.Cell):
        embedding_table: Tensor, embedding_table with shape of (vocab_size, embedding_size)
    """
    def __init__(self, config):
        super(Embedding, self).__init__()
        self.get_attention_mask = AttentionMask(config)
        # Word embedding
        self.word_embedding = EmbeddingLookup(config).set_comm_fusion(1)
        if config.load_ckpt_path:
            # Loading the embedding table from the ckpt path:
            embedding_path = os.path.join(config.load_ckpt_path, 'position_embedding.npy')
            if os.path.exists(embedding_path):
                p_table = np.load(embedding_path)
                position_table_param = Tensor(p_table, mstype.float32)
            else:
                raise ValueError(f"{embedding_path} file does not exist, please check whether the position_embedding file exists.")
        else:
            position_table_param = TruncatedNormal(0.02)
        super(PanguAlphaEmbedding, self).__init__()
        self.embedding = EmbeddingCell(config)
        if config.stage_num > 1:
            self.embedding.pipeline_stage = 0
        self.mask = MaskCell(config)

        # Position embedding
        self.position_embedding = nn.Embedding(
            config.seq_length,
            config.embedding_size,
            embedding_table=position_table_param).set_comm_fusion(1)
        self.word_embedding.embedding_table.parallel_optimizer = False
        self.position_embedding.embedding_table.parallel_optimizer = False
        self.position_embedding.gather.shard(((1, 1), (config.dp,)))
        self.position_embedding.expand.shard(((config.dp, 1),))
        self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
        self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
        self.dropout = Dropout(1 - config.dropout_rate)
        self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
        self.eod_reset = config.eod_reset
        self.use_past = config.use_past
        self.is_first_iteration = True

    def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, valid_index=None):
    def construct(self, input_ids, input_mask, table, input_position, attention_mask, valid_index=None):
        """
        Calculate input embeddings via input token ids and input position
        """
        # Word embedding
        input_embedding, embedding_table = self.word_embedding(input_ids)
        # If eod_reset is disabled, there will be only one input from the dataset, i.e., input_ids,
        # and the corresponding input_position and attention_mask will be derived from it.
        if not self.eod_reset:
            batch_size, seq_length = F.shape(input_ids)
            attention_mask = self.get_attention_mask(input_mask)
            if self.use_past and not self.is_first_iteration:
                input_position = valid_index.view(1, seq_length)
            else:
                input_position = F.tuple_to_array(F.make_range(seq_length))
                input_position = P.Tile()(input_position, (batch_size, 1))
        position_embedding = self.position_embedding(input_position)
        # Input features [bs, seq_length, embedding_size]
        hidden_states = self.add(input_embedding, position_embedding)
        hidden_states = self.dropout(hidden_states)
        hidden_states = P.Cast()(hidden_states, mstype.float16)
        attention_mask = self.expand_dims(attention_mask, 1)
        return hidden_states, attention_mask, embedding_table

        hidden_states = self.embedding(input_ids, table, input_position, valid_index)
        attention_mask = self.mask(attention_mask)
        return hidden_states, attention_mask
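Note: PanguAlphaEmbedding now takes the shared embedding table explicitly and simply delegates to EmbeddingCell and MaskCell. A minimal call sketch, matching the new construct signature used by the backbone later in this diff:

# hidden_states, attention_mask = self.embedding(input_ids, input_mask, table,
#                                                input_position, attention_mask, batch_valid_length)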

class PanguAlpha_Model(nn.Cell):
    """
    The backbone of the PanguAlpha network
    Args:
        config(PanguAlphaConfig): the config of the network
    Inputs:
        input_ids: the tokenized inputs with datatype int32
        input_mask: the mask indicating whether each position is a valid input
        input_position: the position index of each token
        attention_mask: the attention mask for the self-attention module
        init_reset: whether to reset the saved key and value states
        batch_valid_length: the valid input sequence length without padding
    Returns:
        output_state: Tensor, the output logit of the backbone
        embedding_table: Tensor, the embedding table for the vocabulary
    """
    def __init__(self, config):
        super(PanguAlpha_Model, self).__init__()
        self.embedding = Embedding(config)
        self.blocks = nn.CellList()
        # Total fusion groups for HCCL operators. Specifically, HCCL operators of the same type in the same group will be fused.
        fusion_group_num = 4
        fusion_group_size = config.num_layers // fusion_group_num
        fusion_group_size = max(fusion_group_size, 1)

        num_layers = config.num_layers
        # If top_query_attention is enabled, replace the last normal self-attention layer with the top_query_attention layer
        if config.use_top_query_attention:
            num_layers -= 1
        self.num_layers = num_layers
        print("After setting the layer is:", num_layers, flush=True)

        for i in range(num_layers):
            per_block = Decoder(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
            # Each layer will be recomputed in the backward process. The output activation of each layer will be saved;
            # in other words, in the backward process each block will be almost totally recomputed.
            if config.use_recompute:
                per_block.recompute()
            self.blocks.append(per_block)
        if config.self_layernorm:
            self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(
                mstype.float32).set_comm_fusion(
                    int((num_layers - 1) / fusion_group_size) + 2)
        else:
            self.layernorm = nn.LayerNorm((config.embedding_size,)).to_float(
                mstype.float32).set_comm_fusion(
                    int((num_layers - 1) / fusion_group_size) + 2)
            self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
        self.layernorm.gamma.parallel_optimizer = False
        self.layernorm.beta.parallel_optimizer = False
        self.use_past = config.use_past
        self.past = tuple([None] * config.num_layers)
        self.dtype = config.compute_dtype
        # If top_query_attention is enabled, the input_position representing the position ids will be used as the index
        # for a query embedding table to obtain the top query hidden states; together with the previous outputs of the normal
        # self-attention layers, a new attention layer will be attached to the output of the model
        if config.use_top_query_attention:
            if config.load_ckpt_path:
                # Loading the embedding table from the ckpt path:
                embedding_path = os.path.join(config.load_ckpt_path, 'top_query_embedding.npy')
                if os.path.exists(embedding_path):
                    top_query_table = np.load(embedding_path)
                    top_query_table_param = Tensor(top_query_table, mstype.float32)
                else:
                    raise ValueError(
                        f"{embedding_path} file does not exist, please check whether the top_query_embedding file exists.")
            else:
                top_query_table_param = TruncatedNormal(0.02)

            self.top_query_embedding = nn.Embedding(config.seq_length, config.embedding_size,
                                                    embedding_table=top_query_table_param)
            # If the model is initialized with fp16, the fusion of layernorm (fp32 gradient) will mix up with
            # the bias parameters in the linear layers (fp16 gradient), causing a dtype error for communication operators,
            # so we fuse the communications of the embedding into a large fusion value (+100)
            self.top_query_embedding.set_comm_fusion(int((config.num_layers - 1) / fusion_group_size) + 200)
            self.top_query_embedding.embedding_table.parallel_optimizer = False
            self.top_query_embedding.gather.shard(((1, 1), (config.dp,)))
            self.top_query_embedding.expand.shard(((config.dp, 1),))
            self.top_query_layer = QueryLayer(config)
            if config.use_recompute:
                self.top_query_layer.recompute()
            self.top_query_layer.set_comm_fusion(int((config.num_layers - 1) / fusion_group_size) + 2)

        self.use_top_query_attention = config.use_top_query_attention

    def construct(self, input_ids, input_mask, input_position=None, attention_mask=None,
                  init_reset=True, batch_valid_length=None):
        """PanguAlpha model"""
        # embedding for input_ids and the lower-triangle-like attention_mask matrix
        hidden_states, attention_mask, embedding_table = self.embedding(input_ids, input_mask,
                                                                        input_position, attention_mask,
                                                                        batch_valid_length)

        # Loop through each self-attention layer
        for i in range(self.num_layers):
            hidden_states, _ = self.blocks[i](hidden_states,
                                              attention_mask, init_reset, batch_valid_length)

        output_state = self.layernorm(hidden_states)
        output_state = F.cast(output_state, self.dtype)

        # Top query attention layer
        if self.use_top_query_attention:
            top_query_hidden_states = self.top_query_embedding(input_position)
            output_state, _ = self.top_query_layer(output_state, top_query_hidden_states,
                                                   attention_mask, init_reset, batch_valid_length)
        return output_state, embedding_table


class PanguAlpha_ModelPipeline(nn.Cell):
    """
    The backbone of the PanguAlpha network
    Args:
@@ -1113,55 +899,92 @@ class PanguAlpha_ModelPipeline(nn.Cell):
        embedding_table: Tensor, the embedding table for the vocabulary
    """
    def __init__(self, config):
        super(PanguAlpha_ModelPipeline, self).__init__()
        self.pangu_alpha_embedding = PanguAlpha_EmbeddingPipeLine(config).set_comm_fusion(1)
        self.pangu_alpha_embedding.pipeline_stage = 0
        self.pangu_alpha_mask = PanguAlpha_Mask(config)
        super(PanguAlpha_Model, self).__init__()
        self.embedding = PanguAlphaEmbedding(config)
        self.blocks = nn.CellList()
        self.top_query_embedding = nn.Embedding(config.seq_length, config.embedding_size,
                                                embedding_table=TruncatedNormal(0.02))
        self.top_query_embedding.gather.shard(((1, 1), (config.dp,)))
        self.top_query_embedding.expand.shard(((config.dp, 1),))
        for i in range(config.num_layers):
            if i == config.num_layers - 1:
                self.top_query_embedding.set_comm_fusion(2)
                self.top_query_embedding.pipeline_stage = i * config.stage_num // config.num_layers
                per_block = QueryLayer(config).set_comm_fusion(2)
            else:
                per_block = Decoder(config, i + 1).set_comm_fusion(2)
            per_block.pipeline_stage = i * config.stage_num // config.num_layers
            per_block.recompute()
            self.blocks.append(per_block)

        if config.self_layernorm:
            self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
        else:
            self.layernorm = nn.LayerNorm(
                (config.embedding_size,)).to_float(mstype.float32)
            self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
        self.layernorm.set_comm_fusion(2)
        self.layernorm.pipeline_stage = config.stage_num - 1
        self.use_past = config.use_past
        self.dtype = config.compute_dtype
        self.num_layers = config.num_layers
        self.is_pipeline = (config.stage_num > 1)
        if self.is_pipeline:
            self.top_query_embedding_table = Parameter(initializer(TruncatedNormal(0.02),
                                                                   [config.vocab_size, config.embedding_size]),
                                                       name='embedding_table', parallel_optimizer=False)
            self.top_query_embedding = EmbeddingLookup()
            for i in range(config.num_layers):
                if i == config.num_layers - 1:
                    self.top_query_embedding_table.comm_fusion = 2
                    self.top_query_embedding_table.add_pipeline_stage(i * config.stage_num // config.num_layers)
                    per_block = QueryLayer(config).set_comm_fusion(2)
                else:
                    per_block = Decoder(config, i + 1).set_comm_fusion(2)
                per_block.pipeline_stage = i * config.stage_num // config.num_layers
                per_block.recompute()
                self.blocks.append(per_block)

            self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
            self.layernorm.set_comm_fusion(2)
            self.layernorm.pipeline_stage = config.stage_num - 1
        else:
            # The input_position representing the position ids will be used as the index
            # for a query embedding table to obtain the top query hidden states; together with the previous outputs of the normal
            # self-attention layers, a new attention layer will be attached to the output of the model
            if config.load_ckpt_path:
                # Loading the embedding table from the ckpt path:
                embedding_path = os.path.join(config.load_ckpt_path, 'top_query_embedding.npy')
                if os.path.exists(embedding_path):
                    top_query_table = np.load(embedding_path)
                    top_query_table_param = Tensor(top_query_table, mstype.float32)
                else:
                    raise ValueError(
                        f"{embedding_path} file does not exist, please check whether the top_query_embedding file exists.")
            else:
                top_query_table_param = TruncatedNormal(0.02)
            self.top_query_embedding_table = Parameter(initializer(top_query_table_param,
                                                                   [config.vocab_size, config.embedding_size]),
                                                       name='embedding_table', parallel_optimizer=False)
            self.top_query_embedding = EmbeddingLookup()
            # Total fusion groups for HCCL operators. Specifically, HCCL operators of the same type in the same group will be fused.
            fusion_group_num = 4
            fusion_group_size = config.num_layers // fusion_group_num
            fusion_group_size = max(fusion_group_size, 1)
            for i in range(config.num_layers):
                if i == config.num_layers - 1:
                    per_block = QueryLayer(config).set_comm_fusion(int(i / fusion_group_size) + 2)
                else:
                    per_block = Decoder(config, i + 1).set_comm_fusion(int(i / fusion_group_size) + 2)
                # Each layer will be recomputed in the backward process. The output activation of each layer will be saved;
                # in other words, in the backward process each block will be almost totally recomputed.
                per_block.recompute()
                self.blocks.append(per_block)
            self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
            self.layernorm.set_comm_fusion(int((config.num_layers - 1) / fusion_group_size) + 2)
            self.top_query_embedding_table.comm_fusion = int((config.num_layers - 1) / fusion_group_size) + 2
        self.top_query_embedding.gather.shard(((1, 1), (config.dp,)))

    def construct(self, input_ids, input_mask, table, input_position, attention_mask,
                  init_reset=True, batch_valid_length=None):
        """PanguAlpha model"""

        hidden_states = self.pangu_alpha_embedding(input_ids, table, input_position)
        attention_mask = self.pangu_alpha_mask(input_mask, attention_mask)

        # embedding for input_ids and the lower-triangle-like attention_mask matrix
        hidden_states, attention_mask = self.embedding(input_ids, input_mask, table,
                                                       input_position, attention_mask,
                                                       batch_valid_length)
        for i in range(self.num_layers-1):
            hidden_states, _ = self.blocks[i](hidden_states,
                                              attention_mask, init_reset, batch_valid_length)

        top_query_hidden_states = self.top_query_embedding(input_position)
        hidden_states, present_layer = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states,
                                                                      attention_mask, init_reset, batch_valid_length)
        output_state = self.layernorm(hidden_states)
        output_state = F.cast(output_state, self.dtype)
        return output_state, present_layer
        if self.is_pipeline:
            top_query_hidden_states = self.top_query_embedding(input_position.view(-1,), self.top_query_embedding_table)
            hidden_states, _ = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states,
                                                              attention_mask, init_reset, batch_valid_length)
            output_state = self.layernorm(hidden_states)
            output_state = F.cast(output_state, self.dtype)
        else:
            output_state = self.layernorm(hidden_states)
            output_state = F.cast(output_state, self.dtype)
            top_query_hidden_states = self.top_query_embedding(input_position.view(-1,), self.top_query_embedding_table)
            output_state, _ = self.blocks[self.num_layers-1](output_state, top_query_hidden_states,
                                                             attention_mask, init_reset, batch_valid_length)
        return output_state
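Note: a condensed view of the control flow in the merged construct above (a summary of the code, not new behaviour). In pipeline mode the top-query layer runs before the final LayerNorm, otherwise it runs after it, mirroring the two classes that were merged:

# pipeline:      blocks[0 .. n-2] -> blocks[n-1] (QueryLayer) -> layernorm -> cast
# non-pipeline:  blocks[0 .. n-2] -> layernorm -> cast -> blocks[n-1] (QueryLayer)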

class PanguAlpha_Head(nn.Cell):
    """
@@ -1193,35 +1016,6 @@ class PanguAlpha_Head(nn.Cell):


class PanguAlpha(nn.Cell):
    """
    The PanguAlpha network, consisting of two parts: the backbone and the head
    Args:
        config(PanguAlphaConfig): the config of the network
    Inputs:
        input_ids: the tokenized inputs
        input_mask: the mask indicating whether each position is a valid input
        input_position: the position index of each token
        attention_mask: the attention mask for the self-attention module
        init_reset: whether to reset the saved key and value states
        batch_valid_length: the valid input sequence length without padding
    Returns:
        logits: Tensor, the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size)
    """
    def __init__(self, config):
        super(PanguAlpha, self).__init__()
        # Network backbone of PanguAlpha
        self.backbone = PanguAlpha_Model(config)
        # Network head to get logits over the vocabulary
        self.head = PanguAlpha_Head(config)

    def construct(self, input_ids, input_mask, input_position=None, attention_mask=None,
                  init_reset=True, batch_valid_length=None):
        output_states, embedding_table = self.backbone(
            input_ids, input_mask, input_position, attention_mask, init_reset, batch_valid_length)
        logits = self.head(output_states, embedding_table)
        return logits


class PanguAlphaPipeline(nn.Cell):
    """
    The PanguAlpha network, consisting of two parts: the backbone and the head
    Args:
@@ -1234,21 +1028,38 @@ class PanguAlphaPipeline(nn.Cell):
        logits: Tensor, the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size)
    """
    def __init__(self, config):
        super(PanguAlphaPipeline, self).__init__()
        self.backbone = PanguAlpha_ModelPipeline(config)
        super(PanguAlpha, self).__init__()
        # Network head to get logits over the vocabulary
        self.head = PanguAlpha_Head(config)
        self.head.pipeline_stage = config.stage_num - 1
        self.vocab_size = config.vocab_size
        self.embedding_size = config.embedding_size
        self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]),
                                         name="embedding_table")
        self.embedding_table.add_pipeline_stage(self.backbone.blocks[0].pipeline_stage)
        self.embedding_table.add_pipeline_stage(self.head.pipeline_stage)
        self.backbone = PanguAlpha_Model(config)
        if config.stage_num > 1:

            self.head.pipeline_stage = config.stage_num - 1
            self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]),
                                             name="embedding_table", parallel_optimizer=False)
            self.embedding_table.add_pipeline_stage(self.backbone.blocks[0].pipeline_stage)
            self.embedding_table.add_pipeline_stage(self.head.pipeline_stage)
        else:
            if config.load_ckpt_path:
                # Loading the embedding table from the ckpt path:
                embedding_path = os.path.join(config.load_ckpt_path, 'word_embedding.npy')
                if os.path.exists(embedding_path):
                    e_table = np.load(embedding_path)
                    e_table = Tensor(e_table, mstype.float32)
                    self.embedding_table = Parameter(e_table, name="embedding_table", parallel_optimizer=False)
                else:
                    raise ValueError(f"{embedding_path} file does not exist, "
                                     f"please check whether the word_embedding file exists.")
            else:
                self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]),
                                                 name="embedding_table", parallel_optimizer=False)

    def construct(self, input_ids, input_mask, input_position, attention_mask,
                  init_reset=True, batch_valid_length=None):
        output_states, _ = self.backbone(input_ids, input_mask, self.embedding_table,
                                         input_position, attention_mask, init_reset, batch_valid_length)
        output_states = self.backbone(input_ids, input_mask, self.embedding_table,
                                      input_position, attention_mask, init_reset, batch_valid_length)
        logits = self.head(output_states, self.embedding_table)
        return logits
@@ -1324,7 +1135,6 @@ class CrossEntropyLoss(nn.Cell):
        loss = self.div2(numerator, denominator)
        return loss


class PanguAlphaWithLoss(nn.Cell):
    """
    PanguAlpha training loss
@@ -1342,65 +1152,14 @@ class PanguAlphaWithLoss(nn.Cell):
        super(PanguAlphaWithLoss, self).__init__(auto_prefix=False)
        self.network = network
        self.loss = loss
        # id for end_of_sentence, 6 in the vocabulary
        self.eos_token = eos_token
        self.slice = P.StridedSlice().shard(((config.dp, 1),))
        self.not_equal = P.NotEqual().shard(((config.dp, 1), ()))
        self.batch_size = config.batch_size
        self.len = config.seq_length
        self.eod_reset = config.eod_reset
        if self.eod_reset:
            self.slice_mask = P.StridedSlice().shard(((config.dp, 1, 1),))

    def construct(self, input_ids, input_position=None, attention_mask=None):
        r"""
        PanguAlphaWithLoss
        """
        # input_ids [bs, seq_length+1]
        # input_position [bs, seq_length] only available when eod_reset is enabled
        # attention_mask [bs, seq_length, seq_length] only available when eod_reset is enabled
        # Get the input tokens [bs, seq_length]
        tokens = self.slice(input_ids, (0, 0), (self.batch_size, -1), (1, 1))

        if self.eod_reset:
            input_position = self.slice(input_position, (0, 0), (self.batch_size, self.len), (1, 1))
            attention_mask = self.slice_mask(attention_mask, (0, 0, 0),
                                             (self.batch_size, self.len, self.len),
                                             (1, 1, 1))
        # Check whether there is padding in the inputs
        input_mask = F.cast(self.not_equal(tokens, self.eos_token),
                            mstype.float32)
        logits = self.network(tokens, input_mask, input_position, attention_mask)
        # Get the labels corresponding to the input tokens
        labels = self.slice(input_ids, (0, 1), (self.batch_size, self.len + 1),
                            (1, 1))
        # Loss
        output = self.loss(logits, labels, input_mask)
        return output
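Note: the two StridedSlice calls above implement next-token prediction: the first seq_length tokens are the inputs and the sequence shifted by one position is the label. A tiny illustration with made-up token ids:

# input_ids = [[t0, t1, t2, t3]]   # [bs, seq_length + 1]
# tokens    = [[t0, t1, t2]]       # slice from (0, 0) to (bs, -1)
# labels    = [[t1, t2, t3]]       # slice from (0, 1) to (bs, seq_length + 1)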

class PanguAlphaWithLossPipeline(nn.Cell):
    """
    PanguAlpha training loss
    Args:
        network: backbone network of PanguAlpha
        loss: loss function, e.g., crossentropy
        eos_token: the end_of_sentence token
    Inputs:
        input_ids: the tokenized inputs
        past: the previous feature map
    Returns:
        output: Tensor, the loss of the network
    """
    def __init__(self, config, network, loss, eos_token=6):
        super(PanguAlphaWithLossPipeline, self).__init__(auto_prefix=False)
        self.network = network
        self.loss = loss
        self.eos_token = eos_token
        self.slice = P.StridedSlice().shard(((config.dp, 1),))
        self.not_equal = P.NotEqual().shard(((config.dp, 1), ()))
        self.batch_size = config.batch_size
        self.len = config.seq_length
        self.micro_batch_step = config.micro_size
        self.micro_batch_step = 1
        if config.stage_num > 1:
            self.micro_batch_step = config.micro_size

    def construct(self, input_ids, input_position, attention_mask):
        tokens = self.slice(input_ids, (0, 0), (self.batch_size // self.micro_batch_step, -1), (1, 1))
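Note: because the loss cell is wrapped in PipelineCell, each call sees one micro-batch, so the slice length becomes batch_size // micro_batch_step; with stage_num == 1 the step stays 1 and the behaviour matches the non-pipeline path. Illustrative numbers:

# batch_size = 32, micro_size = 8 -> each micro-batch slice holds 32 // 8 = 4 sequences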
@@ -36,15 +36,13 @@ class PANGUALPHAConfig:
                 dropout_rate=0.1,
                 compute_dtype=mstype.float16,
                 use_past=False,
                 self_layernorm=True,
                 word_emb_dp=True,
                 stage_num=16,
                 eod_reset=True,
                 micro_size=32,
                 load_ckpt_path=None,
                 use_top_query_attention=True,
                 param_init_type=mstype.float32,
                 use_recompute=True):
                 param_init_type=mstype.float32):
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
@@ -61,8 +59,6 @@ class PANGUALPHAConfig:
        self.use_past = use_past
        self.dp = data_parallel_num
        self.mp = model_parallel_num
        # Whether to use the self-implemented layernorm
        self.self_layernorm = self_layernorm
        self.stage_num = stage_num
        self.micro_size = micro_size
        self.word_emb_dp = word_emb_dp
@@ -70,7 +66,6 @@ class PANGUALPHAConfig:
        # Used for loading embedding tables
        self.load_ckpt_path = load_ckpt_path
        self.use_top_query_attention = use_top_query_attention
        self.use_recompute = use_recompute
        self.param_init_type = param_init_type

    def __str__(self):
@@ -89,7 +84,8 @@ def set_parse(args_opt):
        args_opt.embedding_size = 16384
        args_opt.num_layers = 64
        args_opt.num_heads = 128
        args_opt.per_batch_size = 1
        if args_opt.per_batch_size == 0:
            args_opt.per_batch_size = 1
        args_opt.word_emb_dp = 0
        if args_opt.run_type == "train":
            args_opt.start_lr = 6e-5
@@ -117,9 +113,14 @@ def set_parse(args_opt):
            args_opt.optimizer_shard = 1
            args_opt.stage_num = 1
            args_opt.micro_size = 1
            args_opt.full_batch = 0
            if args_opt.per_batch_size == 0:
                args_opt.per_batch_size = 8
        elif args_opt.run_type == "predict":
            args_opt.stage_num = 1
            args_opt.micro_size = 1
            if args_opt.per_batch_size == 0:
                args_opt.per_batch_size = 1
    elif args_opt.mode == "2.6B":
        args_opt.embedding_size = 2560
        args_opt.num_layers = 32
@@ -131,6 +132,11 @@ def set_parse(args_opt):
            args_opt.optimizer_shard = 1
            args_opt.stage_num = 1
            args_opt.micro_size = 1
            args_opt.full_batch = 0
            if args_opt.per_batch_size == 0:
                args_opt.per_batch_size = 8
        elif args_opt.run_type == "predict":
            args_opt.stage_num = 1
            args_opt.micro_size = 1
            if args_opt.per_batch_size == 0:
                args_opt.per_batch_size = 1
@@ -25,6 +25,7 @@ from mindspore import context, Parameter
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.communication.management import get_group_size
from mindspore.parallel._utils import _get_enable_parallel_optimizer
from src.utils import ClipByGlobalNorm

GRADIENT_CLIP_TYPE = 1
@@ -61,6 +62,7 @@ def _clip_grad(clip_type, clip_value, grad):


grad_scale = C.MultitypeFuncGraph("grad_scale")
shard_grad_scale = C.MultitypeFuncGraph("shard_grad_scale")
reciprocal = P.Reciprocal()

@@ -77,6 +79,13 @@ def tensor_grad_scale_pipeline(scale, grad, accu_grad):
    _ = F.assign(accu_grad, zeros)
    return new_grad

@shard_grad_scale.register("Tensor", "Tensor", "Tensor")
def tensor_shard_grad_scale_pipeline(scale, grad, accu_grad):
    new_grad = grad * reciprocal(scale)
    accu_grad = F.depend(accu_grad, new_grad)
    _ = F.assign(accu_grad, F.zeros_like(accu_grad))
    return new_grad
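Note: the new shard_grad_scale path is selected when the parallel optimizer (optimizer sharding) is enabled, where the gradients themselves are reduced and then scaled against the local accumulation buffers; the original grad_scale path reduces the accumulated gradients first. A compressed sketch of the selection logic, mirroring the wrapper change later in this diff:

# if self.opt_shard:
#     grads = self.grad_reducer(grads)
#     grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads)
# else:
#     accu_grads = self.grad_reducer(self.accu_grads)
#     grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)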

class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
    """
    Encapsulation class of PanguAlpha network training.
@@ -106,7 +115,7 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
        self.clip = ClipByGlobalNorm(self.weights, config)
        self.cast = P.Cast()

    def construct(self, input_ids, input_position=None, attention_mask=None, layer_past=None, sens=None):
    def construct(self, input_ids, input_position, attention_mask, layer_past=None, sens=None):
        """Defines the computation performed."""
        weights = self.weights
        # Forward process
@@ -194,6 +203,7 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
                                name="loss_scale")
        self.clip = ClipByGlobalNorm(self.weights, self.config)
        self.micro_size = config.micro_size
        self.opt_shard = _get_enable_parallel_optimizer()

    @C.add_flags(has_effect=True)
    def construct(self,
@@ -224,8 +234,12 @@ class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
        flag_sum = self.reduce_sum(init, (0,))
        loss = F.depend(loss, status_clear)
        # apply grad reducer on grads
        accu_grads = self.grad_reducer(self.accu_grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
        if self.opt_shard:
            grads = self.grad_reducer(grads)
            grads = self.hyper_map(F.partial(shard_grad_scale, scaling_sens * self.degree), grads, self.accu_grads)
        else:
            accu_grads = self.grad_reducer(self.accu_grads)
            grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
        if self.enable_global_norm:
            grads, _ = self.clip(grads)
        else:
@@ -330,7 +330,7 @@ def add_training_params(opt):
                     help="Enable optimizer parallel, default is 1")
    opt.add_argument("--per_batch_size",
                     type=int,
                     default=6,
                     default=0,
                     help="The batch size for each data parallel way. default 6")
    opt.add_argument("--start_lr",
                     type=float,
@@ -31,8 +31,7 @@ from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell
from src.dataset import create_dataset
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss,\
    PanguAlphaPipeline, PanguAlphaWithLossPipeline, CrossEntropyLoss
from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell
from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay
@@ -196,10 +195,7 @@ def run_train_pipeline(args_opt):
    The main training process in pipeline mode.
    """
    device_id = int(os.getenv("DEVICE_ID"))
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(save_graphs=False, mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
    context.set_context(variable_memory_max_size="31GB")
    if args_opt.distribute == "true":
        D.init()
@@ -210,7 +206,7 @@ def run_train_pipeline(args_opt):
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=True,
            full_batch=bool(args_opt.full_batch),
            loss_repeated_mean=True,
            enable_parallel_optimizer=bool(args_opt.optimizer_shard),
            pipeline_stages=args_opt.stage_num)
@@ -219,6 +215,12 @@ def run_train_pipeline(args_opt):
    else:
        rank_id = int(os.getenv("RANK_ID"))
        device_num = 1
    # Copy data from the cloud to /cache/Data
    cache_url = '/cache/Data/'
    if args_opt.offline:
        cache_url = args_opt.data_url
    else:
        download_data(src_data_url=args_opt.data_url, tgt_data_path=cache_url, rank=rank_id)
    model_parallel_num = args_opt.op_level_model_parallel_num
    stage_device_num = int(device_num / args_opt.stage_num)
    data_parallel_num = int(stage_device_num / model_parallel_num)
@@ -238,20 +240,17 @@ def run_train_pipeline(args_opt):
        dropout_rate=0.1,
        compute_dtype=mstype.float16,
        use_past=False,
        self_layernorm=True,
        stage_num=args_opt.stage_num,
        micro_size=args_opt.micro_size,
        word_emb_dp=bool(args_opt.word_emb_dp))
    print("===config is: ", config, flush=True)
    pangu_alpha = PanguAlphaPipeline(config)
    pangu_alpha = PanguAlpha(config)
    loss = CrossEntropyLoss(config)
    pangu_alpha_with_loss = PipelineCell(PanguAlphaWithLossPipeline(config, pangu_alpha, loss), config.micro_size)
    pangu_alpha_with_loss = PipelineCell(PanguAlphaWithLoss(config, pangu_alpha, loss), config.micro_size)
    pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss)
    print("=====args_opt is: ", args_opt, flush=True)
    lr = LearningRate(learning_rate=args_opt.start_lr,
                      end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step,
                      decay_steps=args_opt.decay_steps)
    lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step, decay_steps=args_opt.decay_steps)
    params = pangu_alpha.infer_param_pipeline_stage()
    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
    decay_params = list(filter(decay_filter, params))
@@ -269,32 +268,33 @@ def run_train_pipeline(args_opt):
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    else:
        optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
    ds = create_dataset(config.batch_size, data_path=args_opt.data_url, eod_reset=True,
                        data_start_index=0, full_batch=True, column_name=args_opt.data_column_name)
    if context.get_auto_parallel_context("full_batch"):
        ds = create_dataset(config.batch_size, data_path=cache_url, eod_reset=True,
                            data_start_index=0, full_batch=True, column_name=args_opt.data_column_name)
    else:
        if batch_size % stage_device_num != 0:
            raise ValueError("Batch_size should be divisible by device_num")
        ds = create_dataset(config.batch_size, data_path=cache_url, device_num=stage_device_num,
                            rank=rank_id, eod_reset=True, data_start_index=0, full_batch=False,
                            column_name=args_opt.data_column_name)
    epoch_num = args_opt.epoch_size
    step_per_epoch = ds.get_dataset_size()
    callback_size = args_opt.sink_size
    actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
    callback = [
        TimeMonitor(callback_size),
        LossCallBack(callback_size, rank_id, config.stage_num)
    ]
    callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, config.stage_num)]
    loss_scale_value = math.pow(2, 32)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value,
                                             scale_factor=2,
                                             scale_window=1000)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
    pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
        pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell)
    model = Model(pangu_alpha_with_grads)
    model.train(actual_epoch_num,
                ds,
                callbacks=callback,
                sink_size=callback_size,
                dataset_sink_mode=True)
    model.train(actual_epoch_num, ds, callbacks=callback,
                sink_size=callback_size, dataset_sink_mode=True)

if __name__ == "__main__":
    opt = get_args()
    set_parse(opt)
    if opt.per_batch_size == 0:
        raise ValueError("The per_batch_size has not been configured.")
    if opt.stage_num > 1:
        run_train_pipeline(opt)
    else: