!18905 pangu_pipeline
Merge pull request !18905 from yao_yf/pangu_pipeline
commit abc77e13e9
@@ -225,8 +225,9 @@ bool CompFunc(const AnfNodePtr &node1, const AnfNodePtr &node2) {
   if (rank_tag2 == nullptr) {
     rank_tag2 = prim2->GetAttr(DEST_RANK);
   }
-  MS_EXCEPTION_IF_NULL(rank_tag1);
-  MS_EXCEPTION_IF_NULL(rank_tag2);
+  if (!rank_tag1 || !rank_tag2) {
+    return false;
+  }
   auto rank1_value = GetValue<int64_t>(rank_tag1);
   auto rank2_value = GetValue<int64_t>(rank_tag2);
   if (rank1_value == rank2_value) {
@@ -136,7 +136,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
 // define the parse constant
 const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1;
 const char CUSTOM_BPROP_NAME[] = "bprop";
-const char STAGE_NAME[] = "pipeline_stage";
+const char STAGE_NAME[] = "_pipeline_stage";

 // define the Namespace name
 const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast";  // for ast type namespace
@@ -154,6 +154,7 @@ class Parameter(Tensor_):
        self._cast_type = None
        self._unique = False
        self.is_in_parallel = _is_in_parallel_mode()
        self._pipeline_stage_list = []
        if isinstance(default_input, (Tensor_, Tensor)):
            Tensor_.__init__(self, default_input.dtype, default_input.shape)
        elif isinstance(default_input, int):
@@ -452,6 +453,9 @@ class Parameter(Tensor_):
        new_param.param_info = self.param_info
        return new_param

    def add_pipeline_stage(self, stage):
        self._pipeline_stage_list.append(stage)

    def set_data(self, data, slice_shape=False):
        """
        Set Parameter's data.
@@ -230,6 +230,18 @@ class Cell(Cell_):
            raise TypeError("'parallel_parameter_name_list' must be list type.")
        self._parallel_parameter_name_list = value

    @property
    def pipeline_stage(self):
        return self._pipeline_stage

    @pipeline_stage.setter
    def pipeline_stage(self, value):
        if not isinstance(value, int):
            raise TypeError("'pipeline_stage' must be int type.")
        self._pipeline_stage = value
        for item in self.trainable_params():
            item.add_pipeline_stage(value)

    @property
    def parallel_parameter_merge_net_dict(self):
        return self._parallel_parameter_merge_net_dict
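
Setting pipeline_stage on a cell propagates the stage to every trainable parameter through the new Parameter.add_pipeline_stage method. A minimal sketch of that behaviour; ToyBlock and the layer sizes are hypothetical and used only for illustration:

import mindspore.nn as nn

class ToyBlock(nn.Cell):
    """Hypothetical block used only to show the stage bookkeeping."""
    def __init__(self):
        super(ToyBlock, self).__init__()
        self.fc = nn.Dense(16, 16)

    def construct(self, x):
        return self.fc(x)

block0, block1 = ToyBlock(), ToyBlock()
block0.pipeline_stage = 0    # every trainable parameter of block0 records stage 0
block1.pipeline_stage = 1    # every trainable parameter of block1 records stage 1
# A parameter consumed by operators on two stages can record the extra stage by hand:
block1.fc.weight.add_pipeline_stage(0)
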
@@ -1297,6 +1309,41 @@ class Cell(Cell_):
        for cell in self.cells():
            cell.recompute(mode, True)

    def infer_param_pipeline_stage(self):
        """
        Infer the pipeline stages of all parameters in the cell.

        Notes:
            - If a parameter does not belong to any cell which has been set pipeline_stage,
              the parameter should use add_pipeline_stage to add its pipeline_stage information.
            - If a parameter P has been used by two operators in different stages "stageA" and "stageB",
              the parameter P should use P.add_pipeline_stage(stageA) and P.add_pipeline_stage(stageB)
              to add its stage information before using infer_param_pipeline_stage.

        Returns:
            The params belonging to the current stage in pipeline parallel.

        Raises:
            RuntimeError: If there is a parameter that does not belong to any stage.
        """
        from mindspore.communication import get_group_size, get_rank
        stage_num = context.get_auto_parallel_context("pipeline_stages")
        device_num = get_group_size()
        rank_id = get_rank()
        per_stage_devices = device_num // stage_num
        current_stage = rank_id // per_stage_devices
        params = []
        for param in self.trainable_params():
            if not param._pipeline_stage_list:
                raise RuntimeError("The parameter {} does not belong to any stage, "
                                   "please check whether the cell where the param locates"
                                   " has been set pipeline_stage. "
                                   "Otherwise, the parameter should use add_pipeline_stage "
                                   "to add its stage information".format(param.name))
            if current_stage in param._pipeline_stage_list:
                params.append(param)
        return params


class GraphKernel(Cell):
    """
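
A sketch of how infer_param_pipeline_stage would typically feed the optimizer: only parameters recorded for the local rank's stage are kept. TwoStageNet and the layer sizes are hypothetical, and a distributed job with pipeline_stages=2 is assumed to be initialized elsewhere:

from mindspore import context, nn

class TwoStageNet(nn.Cell):
    """Hypothetical net with one Dense per pipeline stage."""
    def __init__(self):
        super(TwoStageNet, self).__init__()
        self.stage0 = nn.Dense(16, 16)
        self.stage1 = nn.Dense(16, 16)
        self.stage0.pipeline_stage = 0
        self.stage1.pipeline_stage = 1

    def construct(self, x):
        return self.stage1(self.stage0(x))

context.set_auto_parallel_context(pipeline_stages=2)
net = TwoStageNet()
local_params = net.infer_param_pipeline_stage()   # parameters of this rank's stage only
optimizer = nn.AdamWeightDecay(local_params, learning_rate=1e-4)
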
@@ -302,6 +302,32 @@ class EmbeddingLookup(nn.Cell):
        return output, self.embedding_table


class EmbeddingLookupPipeline(nn.Cell):
    """
    The embedding lookup table for the vocabulary
    Args:
        config(PanguAlphaConfig): the config of the network
    Inputs:
        input_ids: the tokenized inputs with datatype int32
    Returns:
        output: Tensor, the embedding vector for the input with shape (batch_size, seq_length, embedding_size)
        self.embedding_table: Tensor, the embedding table for the vocabulary
    """
    def __init__(self, config):
        super(EmbeddingLookupPipeline, self).__init__()
        self.vocab_size = config.vocab_size
        self.embedding_size = config.embedding_size
        if config.word_emb_dp:
            self.gather = P.GatherV2().shard(((1, 1), (config.dp, 1)))
        else:
            self.gather = P.GatherV2().shard(((config.mp, 1), (1, 1)))
        self.gather.add_prim_attr("parameter_start", 0)
        self.shape = (-1, config.seq_length, config.embedding_size)

    def construct(self, input_ids, table):
        output = self.gather(table, input_ids, 0)
        return output

class Attention(nn.Cell):
    """
    Self-Attention module for each layer
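
Unlike EmbeddingLookup, the pipeline variant above takes the embedding table as a call argument instead of owning it, which is what lets a single Parameter be shared between the first stage and the head on the last stage. A minimal single-device sketch of driving it; SimpleConfig is a hypothetical stand-in for PANGUALPHAConfig, and EmbeddingLookupPipeline is assumed to be importable from this module:

import numpy as np
import mindspore as ms
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer, Normal

class SimpleConfig:
    # Hypothetical minimal config carrying only the fields the cell reads.
    vocab_size = 1000
    embedding_size = 64
    seq_length = 8
    dp = 1
    mp = 1
    word_emb_dp = True

config = SimpleConfig()
table = Parameter(initializer(Normal(0.02), [config.vocab_size, config.embedding_size]),
                  name="embedding_table")
lookup = EmbeddingLookupPipeline(config)
input_ids = Tensor(np.zeros((2, config.seq_length)), ms.int32)
word_embedding = lookup(input_ids, table)  # shape (2, seq_length, embedding_size)
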
@@ -593,6 +619,44 @@ class Block(nn.Cell):
        output = self.add(x, mlp_logit)
        return output, layer_present

class PanguAlpha_EmbeddingPipeLine(nn.Cell):
    """
    PanguAlpha_EmbeddingPipeLine
    """
    def __init__(self, config):
        super(PanguAlpha_EmbeddingPipeLine, self).__init__()
        self.word_embedding = EmbeddingLookupPipeline(config)
        self.position_embedding = nn.Embedding(config.seq_length,
                                               config.embedding_size,
                                               embedding_table=Normal(0.02))
        self.position_embedding.gather.shard(((1, 1), (config.dp,)))
        self.position_embedding.expand.shard(((config.dp, 1),))
        self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
        self.dropout = Dropout(1 - config.dropout_rate)
        self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
        self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))

    def construct(self, input_ids, table, input_position):
        input_embedding = self.word_embedding(input_ids, table)
        position_embedding = self.position_embedding(input_position)
        hidden_states = self.add(input_embedding, position_embedding)
        hidden_states = self.dropout(hidden_states)
        hidden_states = P.Cast()(hidden_states, mstype.float16)
        return hidden_states


class PanguAlpha_Mask(nn.Cell):
    """
    PanguAlpha_Mask
    """
    def __init__(self, config):
        super(PanguAlpha_Mask, self).__init__()
        self.get_attention_mask = AttentionMask(config)
        self.dtype = config.compute_dtype
        self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
    def construct(self, input_mask, attention_mask):
        attention_mask = self.expand_dims(attention_mask, 1)
        return attention_mask

class QueryLayerAttention(Attention):
    r"""
@@ -828,6 +892,74 @@ class PanguAlpha_Model(nn.Cell):
            present_layer = present_layer + (present,)
        return output_state, present_layer, embedding_table

class PanguAlpha_ModelPipeline(nn.Cell):
    """
    The backbone of the PanguAlpha network
    Args:
        config(PanguAlphaConfig): the config of the network
    Inputs:
        input_ids: the tokenized inputs with datatype int32
        input_mask: the mask indicating whether each position is a valid input
        layer_past: the previous feature map
    Returns:
        output_state: Tensor, the output logit of the backbone
        present_layer: Tensor, the current feature map
        embedding_table: Tensor, the embedding table for the vocabulary
    """
    def __init__(self, config):
        super(PanguAlpha_ModelPipeline, self).__init__()
        self.pangu_alpha_embedding = PanguAlpha_EmbeddingPipeLine(config).set_comm_fusion(1)
        self.pangu_alpha_embedding.pipeline_stage = 0
        self.pangu_alpha_mask = PanguAlpha_Mask(config)
        self.blocks = nn.CellList()
        self.top_query_embedding = nn.Embedding(config.seq_length, config.embedding_size,
                                                embedding_table=TruncatedNormal(0.02))
        self.top_query_embedding.gather.shard(((1, 1), (config.dp,)))
        self.top_query_embedding.expand.shard(((config.dp, 1),))
        for i in range(config.num_layers):
            if i == config.num_layers - 1:
                self.top_query_embedding.set_comm_fusion(2)
                self.top_query_embedding.pipeline_stage = i * config.stage_num // config.num_layers
                per_block = QueryLayer(config).set_comm_fusion(2)
            else:
                per_block = Block(config, i + 1).set_comm_fusion(2)
            per_block.pipeline_stage = i * config.stage_num // config.num_layers
            per_block.recompute()
            self.blocks.append(per_block)

        if config.self_layernorm:
            self.layernorm = LayerNorm((config.embedding_size,), config.dp).to_float(mstype.float32)
        else:
            self.layernorm = nn.LayerNorm(
                (config.embedding_size,)).to_float(mstype.float32)
            self.layernorm.layer_norm.shard(((config.dp, 1, 1), (1,), (1,)))
        self.layernorm.set_comm_fusion(2)
        self.layernorm.pipeline_stage = config.stage_num - 1
        self.use_past = config.use_past
        self.past = tuple([None] * config.num_layers)
        self.dtype = config.compute_dtype
        self.num_layers = config.num_layers

    def construct(self, input_ids, input_mask, table, input_position, attention_mask, layer_past=None):
        """PanguAlpha model"""
        if not self.use_past:
            layer_past = self.past

        hidden_states = self.pangu_alpha_embedding(input_ids, table, input_position)
        attention_mask = self.pangu_alpha_mask(input_mask, attention_mask)

        present_layer = ()
        for i in range(self.num_layers-1):
            hidden_states, present = self.blocks[i](hidden_states,
                                                    attention_mask, layer_past)
            present_layer = present_layer + (present,)
        top_query_hidden_states = self.top_query_embedding(input_position)
        hidden_states, present = self.blocks[self.num_layers-1](hidden_states, top_query_hidden_states,
                                                                attention_mask, layer_past)
        present_layer = present_layer + (present,)
        output_state = self.layernorm(hidden_states)
        output_state = F.cast(output_state, self.dtype)
        return output_state, present_layer

class PanguAlpha_Head(nn.Cell):
    """
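
The per_block.pipeline_stage assignment above spreads the transformer layers evenly across the pipeline stages. A small worked sketch of that mapping, with hypothetical sizes num_layers=32 and stage_num=4:

# Layer-to-stage mapping used above: stage(i) = i * stage_num // num_layers.
num_layers, stage_num = 32, 4            # hypothetical sizes for illustration
mapping = {i: i * stage_num // num_layers for i in range(num_layers)}
assert mapping[0] == 0 and mapping[7] == 0      # layers 0..7  -> stage 0
assert mapping[8] == 1 and mapping[15] == 1     # layers 8..15 -> stage 1
assert mapping[31] == stage_num - 1             # last layer   -> last stage
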
@@ -883,6 +1015,35 @@ class PanguAlpha(nn.Cell):
        logits = self.head(output_states, embedding_table)
        return logits

class PanguAlphaPipeline(nn.Cell):
    """
    The PanguAlpha network, consisting of two parts: the backbone and the head
    Args:
        config(PanguAlphaConfig): the config of the network
    Inputs:
        input_ids: the tokenized inputs
        input_mask: the mask indicating whether each position is a valid input
        past: the previous feature map
    Returns:
        logits: Tensor, the logits of the corresponding inputs with shape (batch_size, seq_length, vocab_size)
    """
    def __init__(self, config):
        super(PanguAlphaPipeline, self).__init__()
        self.backbone = PanguAlpha_ModelPipeline(config)
        self.head = PanguAlpha_Head(config)
        self.head.pipeline_stage = config.stage_num - 1
        self.vocab_size = config.vocab_size
        self.embedding_size = config.embedding_size
        self.embedding_table = Parameter(initializer(Normal(0.02), [self.vocab_size, self.embedding_size]),
                                         name="embedding_table")
        self.embedding_table.add_pipeline_stage(self.backbone.blocks[0].pipeline_stage)
        self.embedding_table.add_pipeline_stage(self.head.pipeline_stage)

    def construct(self, input_ids, input_mask, input_position, attention_mask, past=None):
        output_states, _ = self.backbone(input_ids, input_mask, self.embedding_table,
                                         input_position, attention_mask, past)
        logits = self.head(output_states, self.embedding_table)
        return logits

class CrossEntropyLoss(nn.Cell):
    """
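
The embedding table above is consumed both by the first backbone block (stage 0) and by the head on the last stage, so it registers both stages and infer_param_pipeline_stage keeps it on the ranks of either stage. A minimal sketch of that bookkeeping, assuming stage_num = 4 and a small hypothetical table:

from mindspore import Parameter
from mindspore.common.initializer import initializer, Normal

# Hypothetical small table; mirrors the bookkeeping done in __init__ above with stage_num = 4.
table = Parameter(initializer(Normal(0.02), [1000, 64]), name="embedding_table")
table.add_pipeline_stage(0)      # first backbone block lives on stage 0
table.add_pipeline_stage(3)      # head lives on the last stage (stage_num - 1)
# infer_param_pipeline_stage() will now return this parameter on ranks of stage 0 and stage 3.
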
@@ -1010,6 +1171,38 @@ class PanguAlphaWithLoss(nn.Cell):
        output = self.loss(logits, labels, input_mask)
        return output

class PanguAlphaWithLossPipeline(nn.Cell):
    """
    PanguAlpha training loss
    Args:
        network: backbone network of PanguAlpha
        loss: loss function, e.g., crossentropy
        eos_token: the end_of_sentence token
    Inputs:
        input_ids: the tokenized inputs
        past: the previous feature map
    Returns:
        output: Tensor, the loss of the network
    """
    def __init__(self, config, network, loss, eos_token=6):
        super(PanguAlphaWithLossPipeline, self).__init__(auto_prefix=False)
        self.network = network
        self.loss = loss
        self.eos_token = eos_token
        self.slice = P.StridedSlice().shard(((config.dp, 1),))
        self.not_equal = P.NotEqual().shard(((config.dp, 1), ()))
        self.batch_size = config.batch_size
        self.len = config.seq_length
        self.micro_batch_step = config.micro_size

    def construct(self, input_ids, input_position, attention_mask):
        tokens = self.slice(input_ids, (0, 0), (self.batch_size // self.micro_batch_step, -1), (1, 1))
        input_mask = F.cast(self.not_equal(tokens, self.eos_token), mstype.float32)
        logits = self.network(tokens, input_mask, input_position, attention_mask)
        labels = self.slice(input_ids, (0, 1), (self.batch_size // self.micro_batch_step,
                                                self.len + 1), (1, 1))
        output = self.loss(logits, labels, input_mask)
        return output

class EvalNet(nn.Cell):
    """
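
A small NumPy sketch of the slicing done in construct above, with hypothetical sizes (global batch 32, micro_size 8, seq_length 1024): each micro-batch is assumed to carry batch_size // micro_size rows of seq_length + 1 token ids, from which the inputs and the one-position-shifted labels are cut:

import numpy as np

batch_size, micro_size, seq_length = 32, 8, 1024   # hypothetical sizes
input_ids = np.zeros((batch_size // micro_size, seq_length + 1), dtype=np.int32)

tokens = input_ids[:, :-1]                 # what self.slice(..., (0, 0), (B // micro, -1), ...) keeps
labels = input_ids[:, 1:seq_length + 1]    # next-token targets, shifted by one position
assert tokens.shape == (4, seq_length)
assert labels.shape == (4, seq_length)
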
@@ -89,20 +89,27 @@ def set_parse(args_opt):
        args_opt.embedding_size = 16384
        args_opt.num_layers = 64
        args_opt.num_heads = 128
        args_opt.per_batch_size = 1
        args_opt.word_emb_dp = 0
        if args_opt.run_type == "train":
            args_opt.start_lr = 6e-5
            args_opt.end_lr = 6e-6
            args_opt.optimizer_shard = 0
            args_opt.stage_num = 16
            args_opt.micro_size = 32
            args_opt.op_level_model_parallel_num = 16
            if args_opt.optimizer_shard == 1:
                args_opt.op_level_model_parallel_num = 8
        elif args_opt.run_type == "predict":
            args_opt.stage_num = 4
            args_opt.micro_size = 1
            args_opt.op_level_model_parallel_num = 16
            if args_opt.optimizer_shard == 1:
                args_opt.op_level_model_parallel_num = 8
    elif args_opt.mode == "13B":
        args_opt.embedding_size = 5120
        args_opt.num_layers = 40
        args_opt.num_heads = 40
        args_opt.word_emb_dp = 1
        args_opt.op_level_model_parallel_num = 8
        if args_opt.run_type == "train":
            args_opt.start_lr = 5e-5
@@ -20,8 +20,11 @@ from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
-from mindspore.ops.operations.comm_ops import _VirtualDataset
 from mindspore.nn.wrap.loss_scale import TrainOneStepWithLossScaleCell
+from mindspore import context, Parameter
+from mindspore.context import ParallelMode
+from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
+from mindspore.communication.management import get_group_size
 from src.utils import ClipByGlobalNorm

 GRADIENT_CLIP_TYPE = 1
@@ -65,16 +68,14 @@ reciprocal = P.Reciprocal()
 def tensor_grad_scale(scale, grad):
     return grad * reciprocal(scale)


-class VirtualDatasetOneInputCell(nn.Cell):
-    def __init__(self, backbone):
-        super(VirtualDatasetOneInputCell, self).__init__(auto_prefix=False)
-        self._backbone = backbone
-        self._virtual_dataset = _VirtualDataset()
-
-    def construct(self, *data):
-        data_ = self._virtual_dataset(*data)
-        return self._backbone(*data_)
+@grad_scale.register("Tensor", "Tensor", "Tensor")
+def tensor_grad_scale_pipeline(scale, grad, accu_grad):
+    accu_grad = F.depend(accu_grad, grad)
+    new_grad = accu_grad * reciprocal(scale)
+    accu_grad = F.depend(accu_grad, new_grad)
+    zeros = F.tensor_mul(accu_grad, 0.0)
+    _ = F.assign(accu_grad, zeros)
+    return new_grad

 class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
     """
@@ -102,7 +103,7 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
         self.optimizer = optimizer
         self.default_lr = Tensor([0.0], dtype=mstype.float32)
         self.enable_global_norm = enable_global_norm
-        self.clip = ClipByGlobalNorm(self.weights)
+        self.clip = ClipByGlobalNorm(self.weights, config)
         self.cast = P.Cast()

     def construct(self, input_ids, input_position=None, attention_mask=None, layer_past=None, sens=None):
@@ -142,3 +143,111 @@ class PanguAlphaTrainOneStepWithLossScaleCell(TrainOneStepWithLossScaleCell):
        else:
            succ = self.optimizer(grads)
        return F.depend(loss, succ), cond, scaling_sens

class PanguAlphaTrainPipelineWithLossScaleCell(nn.Cell):
    """
    Encapsulation class of PanguAlpha network training.

    Appends an optimizer to the training network so that the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
    """
    def __init__(self, network, optimizer, config, scale_update_cell=None, enable_global_norm=True):
        super(PanguAlphaTrainPipelineWithLossScaleCell, self).__init__(auto_prefix=False)
        self.config = config
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros")
        self.optimizer = optimizer
        self.enable_global_norm = enable_global_norm
        self.grad = C.GradOperation(get_by_list=True,
                                    sens_param=True)
        self.reducer_flag = False
        self.allreduce = P.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus().add_prim_attr("_side_effect_flag", False)
        self.get_status = P.NPUGetFloatStatus().add_prim_attr("_side_effect_flag", False)
        self.clear_before_grad = P.NPUClearFloatStatus().add_prim_attr("_side_effect_flag", False)
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        #self.depend_parameter_use = P.ControlDepend(depend_mode=1)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.reshape = P.Reshape()
        #self.control = P.ControlDepend(1)
        #self.clip_norm = Tensor(1000.0, mstype.float32)
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")
        self.clip = ClipByGlobalNorm(self.weights, self.config)
        self.micro_size = config.micro_size

    @C.add_flags(has_effect=True)
    def construct(self,
                  input_ids,
                  input_position,
                  attention_mask,
                  past=None,
                  sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(input_ids, input_position, attention_mask)
        if sens is None:
            scaling_sens = self.loss_scale
            scaling_sens = self.reshape(scaling_sens, (1,))
        else:
            scaling_sens = sens
        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        status_clear = self.clear_before_grad(init)
        #clear_depend = self.control(status_clear, self.weights)
        grads = self.grad(self.network, weights)(input_ids,
                                                 input_position,
                                                 attention_mask,
                                                 self.cast(scaling_sens / self.micro_size,
                                                           mstype.float32))
        init = F.depend(init, grads)
        get_status = self.get_status(init)
        init = F.depend(init, get_status)
        flag_sum = self.reduce_sum(init, (0,))
        loss = F.depend(loss, status_clear)
        # apply grad reducer on grads
        accu_grads = self.grad_reducer(self.accu_grads)
        grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads, accu_grads)
        if self.enable_global_norm:
            grads, _ = self.clip(grads)
        else:
            grads = self.hyper_map(
                F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE),
                grads)
        if self.is_distributed:
            # sum overflow flag over devices
            flag_reduce = self.allreduce(flag_sum)
            cond = self.less_equal(self.base, flag_reduce)
        else:
            cond = self.less_equal(self.base, flag_sum)
        overflow = cond
        if sens is None:
            overflow = self.loss_scaling_manager(self.loss_scale, cond)
        if overflow:
            succ = False
        else:
            succ = self.optimizer(grads)
        ret = (loss, overflow, scaling_sens)
        return F.depend(ret, succ)
@@ -26,8 +26,8 @@ from mindspore.ops import functional as F
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR, CosineDecayLR
-from mindspore.parallel._utils import _get_global_rank
-from mindspore.communication.management import get_group_size
+from mindspore.parallel._auto_parallel_context import auto_parallel_context
+from mindspore.communication.management import get_rank, get_group_size, create_group
 from mindspore.nn import AdamWeightDecay
 from mindspore.common import Parameter, ParameterTuple
 from mindspore.common.initializer import initializer
@@ -71,14 +71,12 @@ class FP32StateAdamWeightDecay(AdamWeightDecay):

 get_square_sum = C.MultitypeFuncGraph("get_square_sum")


-@get_square_sum.register("Tensor", "Tensor")
+@get_square_sum.register("Tensor", "Number")
 def _get_square_sum(grad, value):
-    norm = P.ReduceSum(False)(F.square(grad) / value, ())
+    norm = P.ReduceSum(False)(F.square(grad), ()) / value
     norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
     return norm


 apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")


@@ -87,6 +85,41 @@ def _apply_global_norm(clip_norm, global_norm, grad):
    grad = grad * clip_norm / global_norm
    return grad

def _get_model_parallel_group(mp):
    """

    Calculate the communication group of the model-parallel dimension within one pipeline stage

    """
    rank = get_rank()
    stage_nums = auto_parallel_context().get_pipeline_stages()
    device_nums = get_group_size()
    per_stage_device_nums = device_nums // stage_nums
    stage_id = rank // per_stage_device_nums
    local_stage_rank_id = rank % per_stage_device_nums
    index = local_stage_rank_id // mp
    group = range(0, mp)
    rank_str_list = [str(x + index * mp + stage_id * per_stage_device_nums) for x in group]
    rank_list_str = "-".join(rank_str_list)
    rank_list = [x + index * mp + stage_id * per_stage_device_nums for x in group]
    return rank_list, rank_list_str

def _get_pipeline_group():
    """

    Calculate the communication group between all pipeline stages

    """
    rank = get_rank()
    stage_nums = auto_parallel_context().get_pipeline_stages()
    device_nums = get_group_size()
    per_stage_device_nums = device_nums // stage_nums
    local_stage_rank_id = rank % per_stage_device_nums
    group = range(0, stage_nums)
    rank_list = [local_stage_rank_id + x * per_stage_device_nums for x in group]
    rank_str_list = [str(local_stage_rank_id + x * per_stage_device_nums) for x in group]
    rank_list_str = "-".join(rank_str_list)
    return rank_list, rank_list_str

class GlobalNorm(nn.Cell):
    """
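
A worked example of the two group calculations above, using plain Python and hypothetical sizes (16 devices, 4 pipeline stages, mp = 2), for a rank sitting on stage 1:

# Hypothetical topology: 16 devices, 4 pipeline stages (4 devices per stage), mp = 2.
device_nums, stage_nums, mp = 16, 4, 2
rank = 6                                                     # example rank, lives on stage 1

per_stage_device_nums = device_nums // stage_nums            # 4
stage_id = rank // per_stage_device_nums                     # 1
local_stage_rank_id = rank % per_stage_device_nums           # 2
index = local_stage_rank_id // mp                            # 1

model_parallel_group = [x + index * mp + stage_id * per_stage_device_nums for x in range(mp)]
pipeline_group = [local_stage_rank_id + x * per_stage_device_nums for x in range(stage_nums)]

assert model_parallel_group == [6, 7]      # ranks in the same stage sharing model-parallel shards
assert pipeline_group == [2, 6, 10, 14]    # the same local rank across all four stages
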
@@ -107,9 +140,9 @@ class GlobalNorm(nn.Cell):
         self.group_size = get_group_size()
         for item in self.allreduce_filter:
             if item:
-                self.values.append(Tensor([1.0], mstype.float32))
+                self.values.append(1.0)
             else:
-                self.values.append(Tensor([self.group_size * 1.0], mstype.float32))
+                self.values.append(self.group_size * 1.0)
         self.values = tuple(self.values)

     def construct(self, grads):
@@ -119,6 +152,37 @@ class GlobalNorm(nn.Cell):
        global_norms = F.sqrt(P.AllReduce()(F.addn(square_sum_dp)))
        return global_norms

class GlobalNormPipline(nn.Cell):
    """
    Calculate the global norm value of given tensors
    """
    def __init__(self, params, config):
        super(GlobalNormPipline, self).__init__()
        self.norm = nn.Norm()
        self.hyper_map = C.HyperMap()
        self.allreduce_filter = tuple("projection.bias" not in x.name and "layernorm" not in x.name
                                      and "position_embedding.embedding_table" not in x.name for x in params)
        self.allreduce_group_size = ()
        for item in self.allreduce_filter:
            if item:
                self.allreduce_group_size = self.allreduce_group_size + (1.0,)
            else:
                self.allreduce_group_size = self.allreduce_group_size + (config.mp * 1.0,)
        self.length = len(params)
        group_list, group_name = _get_model_parallel_group(config.mp)
        create_group(group_name, group_list)
        self.allreduce = P.AllReduce(group=group_name)
        pipeline_group_list, pipeline_group_name = _get_pipeline_group()
        create_group(pipeline_group_name, pipeline_group_list)
        self.allreduce2 = P.AllReduce(group=pipeline_group_name)

    def construct(self, grads):
        square_sum = self.hyper_map(get_square_sum, grads, self.allreduce_group_size)
        square_reduce_sum = F.addn(square_sum)
        stage_square_reduce_sum = self.allreduce(square_reduce_sum)
        global_square_reduce_sum = self.allreduce2(stage_square_reduce_sum)
        global_norms = F.sqrt(global_square_reduce_sum)
        return global_norms

class ClipByGlobalNorm(nn.Cell):
    """
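
One way to read the allreduce_group_size divisors above: parameters that are replicated across the model-parallel group (projection biases, layernorms, the position embedding table) would otherwise be counted mp times by the stage AllReduce, so their squared sums are divided by config.mp first, while sharded parameters keep the divisor 1.0. A small numeric sketch with a hypothetical mp = 8:

# Hypothetical numbers: one model-parallel group of mp = 8 ranks.
mp = 8.0
replicated_sq = 4.0        # identical squared-grad sum seen on every rank for a replicated param
sharded_sq = [0.5] * 8     # distinct per-rank contributions of one sharded weight

# What the stage AllReduce effectively sums once the divisors above are applied:
total = sum(sharded_sq) + sum(replicated_sq / mp for _ in range(8))
assert total == 4.0 + 4.0  # each parameter contributes its true squared norm exactly once
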
@@ -126,10 +190,12 @@ class ClipByGlobalNorm(nn.Cell):
     Clip grads by global norm

     """
-
-    def __init__(self, params, clip_norm=1.0):
+    def __init__(self, params, config, clip_norm=1.0):
         super(ClipByGlobalNorm, self).__init__()
-        self.global_norm = GlobalNorm(params)
+        if config.stage_num > 1:
+            self.global_norm = GlobalNormPipline(params, config)
+        else:
+            self.global_norm = GlobalNorm(params)
         self.clip_norm = Tensor([clip_norm], mstype.float32)
         self.hyper_map = C.HyperMap()

@@ -140,14 +206,6 @@ class ClipByGlobalNorm(nn.Cell):
         grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
         return grads, global_norm_value


-def _get_model_parallel_group(dp, mp):
-    rank = _get_global_rank()
-    group = range(0, mp)
-    index = rank // dp
-    return [x + index * mp for x in group]


 class LearningRate(LearningRateSchedule):
     """
     Warmup-decay learning rate for PanguAlpha network.
@@ -258,6 +316,10 @@ def add_training_params(opt):
                     type=int,
                     default=2000,
                     help="Warmup step, default is 2000.")
    opt.add_argument("--decay_steps",
                     type=int,
                     default=200000,
                     help="Decay step, default is 200000.")
    opt.add_argument("--optimizer",
                     type=str,
                     default="adam",
@@ -298,6 +360,11 @@ def add_training_params(opt):
                     type=int,
                     default=8,
                     help="The model parallel way. default 8")
    opt.add_argument("--word_emb_dp",
                     type=int,
                     default=1,
                     choices=[0, 1],
                     help="Whether do data parallel in word embedding. default 1")


def get_args(inference=False):
@@ -29,10 +29,11 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 import mindspore.common.dtype as mstype
 from mindspore.parallel import set_algo_parameters
 from mindspore.parallel._cost_model_context import _set_multi_subgraphs
-
+from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell
 from src.dataset import create_dataset
-from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss
-from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell
+from src.pangu_alpha import PanguAlpha, PanguAlphaWithLoss,\
+    PanguAlphaPipeline, PanguAlphaWithLossPipeline, CrossEntropyLoss
+from src.pangu_alpha_wrapcell import PanguAlphaTrainOneStepWithLossScaleCell, PanguAlphaTrainPipelineWithLossScaleCell
 from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
 from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay
 from src.utils import download_data
@@ -132,14 +133,14 @@ def run_train(args_opt):
         micro_size=args_opt.micro_size,
         eod_reset=bool(args_opt.eod_reset),
         param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
-        word_emb_dp=True)
+        word_emb_dp=bool(args_opt.word_emb_dp))
     print("===config is: ", config, flush=True)

     # Define network
     pangu_alpha = PanguAlpha(config)
     loss = CrossEntropyLoss(config)
     pangu_alpha_with_loss = PanguAlphaWithLoss(config, pangu_alpha, loss)
-    pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)
+    pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss)

     print("=====args_opt is: ", args_opt, flush=True)

@@ -189,8 +190,111 @@ def run_train(args_opt):
    print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True)
    model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size, dataset_sink_mode=True)

def run_train_pipeline(args_opt):
    r"""
    The main training process in pipeline.
    """
    device_id = int(os.getenv("DEVICE_ID"))
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="31GB")
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank_id = D.get_rank()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=bool(args_opt.optimizer_shard),
            pipeline_stages=args_opt.stage_num)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        rank_id = int(os.getenv("RANK_ID"))
        device_num = 1
    model_parallel_num = args_opt.op_level_model_parallel_num
    stage_device_num = int(device_num / args_opt.stage_num)
    data_parallel_num = int(stage_device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * data_parallel_num * args_opt.micro_size
    config = PANGUALPHAConfig(
        data_parallel_num=data_parallel_num,
        model_parallel_num=model_parallel_num,
        batch_size=batch_size,
        seq_length=args_opt.seq_length,
        vocab_size=args_opt.vocab_size,
        embedding_size=args_opt.embedding_size,
        num_layers=args_opt.num_layers,
        num_heads=args_opt.num_heads,
        expand_ratio=4,
        post_layernorm_residual=False,
        dropout_rate=0.1,
        compute_dtype=mstype.float16,
        use_past=False,
        self_layernorm=True,
        stage_num=args_opt.stage_num,
        micro_size=args_opt.micro_size,
        word_emb_dp=bool(args_opt.word_emb_dp))
    print("===config is: ", config, flush=True)
    pangu_alpha = PanguAlphaPipeline(config)
    loss = CrossEntropyLoss(config)
    pangu_alpha_with_loss = PipelineCell(PanguAlphaWithLossPipeline(config, pangu_alpha, loss), config.micro_size)
    pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss)
    print("=====args_opt is: ", args_opt, flush=True)
    lr = LearningRate(learning_rate=args_opt.start_lr,
                      end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step,
                      decay_steps=args_opt.decay_steps)
    params = pangu_alpha.infer_param_pipeline_stage()
    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = [{
        'params': decay_params,
        'weight_decay': 1e-1
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]
    if args_opt.optimizer == "lamb":
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    else:
        optimizer = nn.AdamWeightDecay(group_params, learning_rate=lr, beta1=0.9, beta2=0.95, eps=1e-8)
    ds = create_dataset(config.batch_size, data_path=args_opt.data_url, eod_reset=True,
                        data_start_index=0, full_batch=True)
    epoch_num = args_opt.epoch_size
    step_per_epoch = ds.get_dataset_size()
    callback_size = args_opt.sink_size
    actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
    callback = [
        TimeMonitor(callback_size),
        LossCallBack(callback_size, rank_id, config.stage_num)
    ]
    loss_scale_value = math.pow(2, 32)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value,
                                             scale_factor=2,
                                             scale_window=1000)
    pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
        pangu_alpha_with_loss, optimizer=optimizer, config=config, scale_update_cell=update_cell)
    model = Model(pangu_alpha_with_grads)
    model.train(actual_epoch_num,
                ds,
                callbacks=callback,
                sink_size=callback_size,
                dataset_sink_mode=True)

if __name__ == "__main__":
    opt = get_args()
    set_parse(opt)
-    run_train(opt)
+    if opt.stage_num > 1:
+        run_train_pipeline(opt)
+    else:
+        run_train(opt)
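
The batch arithmetic in run_train_pipeline ties the global batch size to the parallel layout. A worked example with hypothetical values (512 devices, 16 stages, 16-way op-level model parallel, per_batch_size 1, micro_size 32):

# Hypothetical layout: 512 devices, 16 pipeline stages, op-level model parallel of 16.
device_num, stage_num, model_parallel_num = 512, 16, 16
per_batch_size, micro_size = 1, 32

stage_device_num = device_num // stage_num                  # 32 devices per stage
data_parallel_num = stage_device_num // model_parallel_num  # 2-way data parallel inside a stage
batch_size = per_batch_size * data_parallel_num * micro_size

assert stage_device_num == 32
assert data_parallel_num == 2
assert batch_size == 64   # global batch handed to create_dataset; PipelineCell splits it into 32 micro-batches
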