forked from mindspore-Ecosystem/mindspore
parent
fbfb42a062
commit
c3a98bab2b
|
@ -553,7 +553,7 @@ class _CellGraphExecutor:
|
|||
"""compile graph in auto parallel mode."""
|
||||
if not auto_parallel_mode:
|
||||
replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
|
||||
self._updata_param_node_default_input(phase, replace)
|
||||
self._update_param_node_default_input(phase, replace)
|
||||
return
|
||||
|
||||
obj.parameter_layout_dict = self._graph_executor.get_parameter_layout(phase)
|
||||
|
@ -564,13 +564,13 @@ class _CellGraphExecutor:
|
|||
if not context.get_context("enable_debug_runtime") or context.get_context("enable_ge"):
|
||||
obj.load_parameter_slice(None)
|
||||
|
||||
self._updata_param_node_default_input(phase, replace)
|
||||
self._update_param_node_default_input(phase, replace)
|
||||
|
||||
# set parallel inputs in sink mode
|
||||
if is_sink_mode:
|
||||
obj.set_parallel_input_with_inputs(*args)
|
||||
|
||||
def _updata_param_node_default_input(self, phase, replace):
|
||||
def _update_param_node_default_input(self, phase, replace):
|
||||
new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
|
||||
return self._graph_executor.updata_param_node_default_input(phase, new_param)
|
||||
|
||||
|
|
|
@ -332,15 +332,14 @@ class LeakyReLU(Cell):
|
|||
validator.check_value_type('alpha', alpha, [float, int], self.cls_name)
|
||||
self.greater_equal = P.GreaterEqual()
|
||||
self.mul = P.Mul()
|
||||
self.maximum = P.Maximum()
|
||||
self.alpha = alpha
|
||||
self.select_op = P.Maximum()
|
||||
if self.alpha > 1:
|
||||
self.select_op = P.Minimum()
|
||||
|
||||
def construct(self, x):
|
||||
alpha_array = P.Cast()(F.scalar_to_array(self.alpha), P.DType()(x))
|
||||
if self.alpha <= 1:
|
||||
out = self.maximum(alpha_array * x, x)
|
||||
else:
|
||||
out = self.maximum(alpha_array * x, x)
|
||||
out = self.select_op(alpha_array * x, x)
|
||||
return out
|
||||
|
||||
|
||||
|
|
|
@ -149,9 +149,9 @@ def _set_fusion_strategy_by_size(data_size_list, group="hccl_world_group"):
|
|||
if not isinstance(data_size, (int, float)):
|
||||
raise TypeError('data_size in data_size_list is invalid')
|
||||
|
||||
c_array_sizeList = _c_array(ctypes.c_float, data_size_list)
|
||||
c_array_size_list = _c_array(ctypes.c_float, data_size_list)
|
||||
c_size_num = ctypes.c_uint(len(data_size_list))
|
||||
c_group = _c_str(group)
|
||||
ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_sizeList)
|
||||
ret = lib_ctype.hcom_set_split_strategy_by_size(c_group, c_size_num, c_array_size_list)
|
||||
if ret != 0:
|
||||
raise RuntimeError('Allreduce split error')
|
||||
|
|
|
@ -30,9 +30,11 @@ def _get_parallel_mode():
|
|||
"""Get parallel mode."""
|
||||
return auto_parallel_context().get_parallel_mode()
|
||||
|
||||
|
||||
def _is_in_auto_parallel_mode():
|
||||
return _get_parallel_mode() in [ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL]
|
||||
|
||||
|
||||
def _get_full_batch():
|
||||
"""Get whether to use full_batch."""
|
||||
return auto_parallel_context().get_full_batch()
|
||||
|
@ -51,12 +53,8 @@ def _check_task_sink_envs():
|
|||
"""
|
||||
import os
|
||||
task_sink = os.getenv("GRAPH_OP_RUN")
|
||||
if task_sink:
|
||||
try:
|
||||
if int(task_sink) == 1:
|
||||
return False
|
||||
except ValueError:
|
||||
return True
|
||||
if task_sink and task_sink.isdigit() and int(task_sink) == 1:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
|
|
|
@ -18,9 +18,9 @@ NOTE:
|
|||
This is an experimental interface that is subject to change and/or deletion.
|
||||
"""
|
||||
from .transformer import *
|
||||
from .layers import *
|
||||
from .loss import *
|
||||
from .op_parallel_config import *
|
||||
from .layers import FixedSparseAttention
|
||||
from .loss import CrossEntropyLoss
|
||||
from .op_parallel_config import OpParallelConfig
|
||||
|
||||
__all__ = []
|
||||
__all__.extend(transformer.__all__)
|
||||
|
|
|
@ -73,6 +73,7 @@ def _valid_type_checks(types, class_name):
|
|||
# as the input of _args_type_validator_check is fixed, so we need to manually change the input order
|
||||
partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name)
|
||||
return partial_check(name, type(value))
|
||||
|
||||
return validator_check_func
|
||||
|
||||
|
||||
|
@ -83,6 +84,7 @@ def _valid_value_checks(types, class_name):
|
|||
# as the input of _args_type_validator_check is fixed, so we need to manually change the input order
|
||||
partial_check = partial(Validator.check_type_name, valid_types=types, prim_name=class_name)
|
||||
return partial_check(name, value)
|
||||
|
||||
return validator_check_func
|
||||
|
||||
|
||||
|
@ -334,7 +336,7 @@ class _Linear(Cell):
|
|||
if self.activation_flag:
|
||||
# some operations has many primitives, need to manually set the shard
|
||||
if self.act_name.lower() == "leakyrelu":
|
||||
self.activation.maximum.shard((strategy_activation[0], strategy_activation[0]))
|
||||
self.activation.select_op.shard((strategy_activation[0], strategy_activation[0]))
|
||||
elif self.act_name.lower() == "logsigmoid":
|
||||
self.activation.mul.shard((strategy_activation[0], ()))
|
||||
self.activation.exp.shard(strategy_activation)
|
||||
|
@ -402,6 +404,7 @@ class FixedSparseAttention(nn.Cell):
|
|||
>>> print(output.shape)
|
||||
(2, 1024, 512)
|
||||
"""
|
||||
|
||||
@_args_type_validator_check(batch_size=Validator.check_positive_int,
|
||||
num_heads=Validator.check_positive_int,
|
||||
size_per_head=Validator.check_positive_int,
|
||||
|
@ -437,7 +440,7 @@ class FixedSparseAttention(nn.Cell):
|
|||
self.transpose = P.Transpose().shard(((dp, 1, mp, 1),))
|
||||
self.batch_matmul = P.BatchMatMul().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
|
||||
self.multiply = P.Mul().shard(((dp, 1, 1, 1), (1, 1, 1)))
|
||||
self.multiply_data = Tensor([-10000.0,], dtype=mstype.float32)
|
||||
self.multiply_data = Tensor([-10000.0], dtype=mstype.float32)
|
||||
self.parallel_config = parallel_config
|
||||
size_per_head_list = [64, 128]
|
||||
if self.seq_length != 1024:
|
||||
|
@ -460,7 +463,7 @@ class FixedSparseAttention(nn.Cell):
|
|||
global_mask_original = -10000 * global_mask_original
|
||||
global_mask_fx = global_mask_original.reshape((self.seq_length // 16, 16, self.global_size // 16, 16))
|
||||
global_mask = np.transpose(global_mask_fx, (2, 0, 1, 3))
|
||||
global_mask = np.repeat(global_mask[np.newaxis, :, :, :, :,], self.batch_size, axis=0)
|
||||
global_mask = np.repeat(global_mask[np.newaxis, :, :, :, :], self.batch_size, axis=0)
|
||||
global_mask = global_mask.reshape((self.batch_size * self.global_size // 16, self.seq_length // 16, 16, 16))
|
||||
self.global_mask = Tensor(global_mask, mstype.float32)
|
||||
self.local_mask_triangle = Tensor(np.tril(local_ones), mstype.float32)
|
||||
|
@ -578,6 +581,7 @@ class _CumSum(Cell):
|
|||
Outputs:
|
||||
Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(_CumSum, self).__init__()
|
||||
dp = config.data_parallel
|
||||
|
@ -598,14 +602,13 @@ class _CumSum(Cell):
|
|||
self.delta = Tensor(1, mstype.int32)
|
||||
self.add = P.TensorAdd().shard(((1,), ()))
|
||||
|
||||
|
||||
def construct(self, expert_mask):
|
||||
# origin_shape: (self.expert_parallel, tokens_per_device, self.expert_dim)
|
||||
# origin_shape: (expert_parallel, tokens_per_device, self.expert_dim)
|
||||
origin_shape = self.shape(expert_mask)
|
||||
tokens_per_device = origin_shape[1]
|
||||
# expert_mask_trans's shape: (self.expert_parallel, self.expert_dim, tokens_per_device)
|
||||
# expert_mask_trans's shape: (expert_parallel, self.expert_dim, tokens_per_device)
|
||||
expert_mask_trans = self.transpose(expert_mask, (0, 2, 1))
|
||||
# expert_mask_reshaped's shape: (self.expert_parallel*self.expert_dim, tokens_per_device)
|
||||
# expert_mask_reshaped's shape: (expert_parallel*self.expert_dim, tokens_per_device)
|
||||
expert_mask_reshaped = self.reshape(expert_mask_trans, (-1, tokens_per_device))
|
||||
|
||||
one_dim = self.expand(self.range(self.start, self.add(self.limit, tokens_per_device), self.delta), 0)
|
||||
|
@ -614,11 +617,11 @@ class _CumSum(Cell):
|
|||
up_tri_matrix = self.greater(one_dim, other_dim)
|
||||
up_tri_matrix = self.cast(up_tri_matrix, mstype.float32)
|
||||
|
||||
# cum_sum's shape: (self.expert_parallel*self.expert_dim, tokens_per_device)
|
||||
# cum_sum's shape: (expert_parallel*self.expert_dim, tokens_per_device)
|
||||
cum_sum = self.matmul(expert_mask_reshaped, up_tri_matrix)
|
||||
# cum_sum's shape: (self.expert_parallel, self.expert_dim, tokens_per_device)
|
||||
# cum_sum's shape: (expert_parallel, self.expert_dim, tokens_per_device)
|
||||
cum_sum = self.reshape(cum_sum, (origin_shape[0], origin_shape[2], tokens_per_device))
|
||||
# cum_sum's shape: (self.expert_parallel, tokens_per_device, self.expert_dim)
|
||||
# cum_sum's shape: (expert_parallel, tokens_per_device, self.expert_dim)
|
||||
cum_sum = self.transpose3(cum_sum, (0, 2, 1))
|
||||
return cum_sum
|
||||
|
||||
|
@ -646,6 +649,7 @@ class Router(Cell):
|
|||
Outputs:
|
||||
Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim)`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
d_model,
|
||||
moe_config,
|
||||
|
@ -704,6 +708,7 @@ class SwitchRouter(Cell):
|
|||
Tensor of shape :math:`(expert\_parallel, tokens\_per\_device, expert\_dim, expert\_capacity)`,
|
||||
Tensor of shape :math:`(1)`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
d_model,
|
||||
moe_config,
|
||||
|
@ -752,9 +757,9 @@ class SwitchRouter(Cell):
|
|||
"""
|
||||
Computing the load balance loss.
|
||||
"""
|
||||
# density_1's shape: (self.expert_parallel, self.expert_dim)
|
||||
# density_1's shape: (expert_parallel, self.expert_dim)
|
||||
density_1 = self.reduce_mean(expert_mask, 1)
|
||||
# density_1_proxy's shape: (self.expert_parallel, self.expert_dim)
|
||||
# density_1_proxy's shape: (expert_parallel, self.expert_dim)
|
||||
density_1_proxy = self.reduce_mean2(router_prob, 1)
|
||||
loss = self.mul(density_1, density_1_proxy)
|
||||
loss = self.reduce_mean3(loss)
|
||||
|
@ -766,20 +771,19 @@ class SwitchRouter(Cell):
|
|||
Keeping only the tokens that fit within expert_capacity.
|
||||
"""
|
||||
cumsum = self.cumsum(expert_mask)
|
||||
# position_in_expert's shape: (self.expert_parallel, tokens_per_device, self.expert_dim)
|
||||
# position_in_expert's shape: (expert_parallel, tokens_per_device, self.expert_dim)
|
||||
position_in_expert = self.mul4(cumsum, expert_mask)
|
||||
less_result = self.less(position_in_expert, expert_capacity)
|
||||
# expert_mask's shape: (self.expert_parallel, tokens_per_device, self.expert_dim)
|
||||
# expert_mask's shape: (expert_parallel, tokens_per_device, self.expert_dim)
|
||||
expert_mask = self.mul5(less_result, expert_mask)
|
||||
# expert_mask_flat's shape: (self.expert_parallel, tokens_per_device)
|
||||
# expert_mask_flat's shape: (expert_parallel, tokens_per_device)
|
||||
expert_mask_flat = self.reduce_sum(expert_mask, -1)
|
||||
|
||||
# Mask out the experts that have overflowed the expert_capacity.
|
||||
# expert_gate's shape: (self.expert_parallel, tokens_per_device)
|
||||
# expert_gate's shape: (expert_parallel, tokens_per_device)
|
||||
expert_gate = self.mul6(expert_gate, expert_mask_flat)
|
||||
return expert_gate, expert_mask_flat, position_in_expert
|
||||
|
||||
|
||||
def construct(self, router_logits):
|
||||
router_logits_shape = self.shape(router_logits)
|
||||
router_logits = self.reshape(router_logits, (-1, router_logits_shape[-1]))
|
||||
|
@ -791,9 +795,9 @@ class SwitchRouter(Cell):
|
|||
|
||||
# Probabilities for each token of what expert is should be sent to
|
||||
router_prob = self.softmax(router_logits)
|
||||
# shape: (self.expert_parallel, tokens_per_device)
|
||||
# shape is : (expert_parallel, tokens_per_device)
|
||||
expert_index, expert_gate = self.argmax(router_prob)
|
||||
# expert_mask's shape: (self.expert_parallel, tokens_per_device, self.expert_dim)
|
||||
# expert_mask's shape: (expert_parallel, tokens_per_device, self.expert_dim)
|
||||
expert_mask = self.onehot(expert_index, self.expert_dim, self.on_value, self.off_value)
|
||||
|
||||
# Computing the load balance loss:
|
||||
|
@ -802,12 +806,12 @@ class SwitchRouter(Cell):
|
|||
expert_gate, expert_mask_flat, position_in_expert = \
|
||||
self._maskout_overflowed_tokens(expert_mask, expert_capacity, expert_gate)
|
||||
|
||||
# combine_tensor's shape: (self.expert_parallel, tokens_per_device)
|
||||
# combine_tensor's shape: (expert_parallel, tokens_per_device)
|
||||
combine_tensor = self.mul7(expert_gate, expert_mask_flat)
|
||||
# combine_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim)
|
||||
# combine_tensor's shape: (expert_parallel, tokens_per_device, self.expert_dim)
|
||||
combine_tensor = self.mul8(self.expand(combine_tensor, -1),
|
||||
self.onehot2(expert_index, self.expert_dim, self.on_value, self.off_value))
|
||||
# combine_tensor's shape: (self.expert_parallel, tokens_per_device, self.expert_dim, self.expert_capacity)
|
||||
# combine_tensor's shape: (expert_parallel, tokens_per_device, self.expert_dim, self.expert_capacity)
|
||||
combine_tensor = self.mul9(self.expand2(combine_tensor, -1),
|
||||
self.onehot3(self.cast(position_in_expert, mstype.int32), expert_capacity,
|
||||
self.on_value, self.off_value))
|
||||
|
|
|
@ -93,7 +93,7 @@ class CrossEntropyLoss(Cell):
|
|||
"""
|
||||
self._check_input(logits, label, input_mask)
|
||||
|
||||
# [bs*seq_length, vocab_size]
|
||||
# the shape is [bs*seq_length, vocab_size]
|
||||
logits = F.cast(logits, mstype.float32)
|
||||
# LogSoftmax for logits over last dimension
|
||||
_, logit_max = self.max(logits)
|
||||
|
|
|
@ -143,7 +143,6 @@ def _check_config(config):
|
|||
device_num = D.get_group_size()
|
||||
optimizer_shard = context.get_auto_parallel_context("enable_parallel_optimizer")
|
||||
|
||||
# dp * pp * pipeline_stage <= device_num
|
||||
if config.data_parallel * config.model_parallel * pipeline_stage > device_num:
|
||||
raise ValueError(f"The product of the data parallel {config.data_parallel}, "
|
||||
f"model parallel {config.model_parallel} "
|
||||
|
@ -155,10 +154,3 @@ def _check_config(config):
|
|||
logger.warning(f"The optimizer shard {optimizer_shard} in auto_parallel_context is not equal to the"
|
||||
f" optimizer_shard {config.optimizer_shard} in the OpParallelConfig. Please check the "
|
||||
f"optimizer_shard to make them consistent.")
|
||||
|
||||
# pipeline_stage <= micro_batch_num
|
||||
if hasattr(config, 'pipeline_stage') and hasattr(config, 'micro_batch_num')\
|
||||
and config.pipeline_stage < config.micro_batch_num:
|
||||
raise ValueError(
|
||||
f"The pipeline stage {config.pipeline_stage} should be greater than the micro_batch_num "
|
||||
f"{config.micro_batch_num}.")
|
||||
|
|
|
@ -396,13 +396,14 @@ class FeedForward(Cell):
|
|||
_check_input_shape(F.shape(x), "x", self.cls_name, 3)
|
||||
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
|
||||
x = self.cast(x, mstype.float16)
|
||||
# [bs, seq_length, ffn_hidden_size]
|
||||
# returned shape is [bs, seq_length, ffn_hidden_size]
|
||||
hidden = self.mapping(x)
|
||||
output = self.projection(hidden)
|
||||
# [bs, seq_length, hidden_size]
|
||||
# returned shape is [bs, seq_length, hidden_size]
|
||||
output = self.dropout(output)
|
||||
return output
|
||||
|
||||
|
||||
@constexpr
|
||||
def calculate_expert_capacity(k, tokens_per_device, capacity_factor, expert_dim):
|
||||
return math.ceil(k * tokens_per_device * capacity_factor / expert_dim)
|
||||
|
@ -588,7 +589,7 @@ class AttentionMask(Cell):
|
|||
mask_right = self.reshape(input_mask, shape_right)
|
||||
attention_mask = self.mul(mask_left, mask_right)
|
||||
lower_traiangle = self.expand_dim(self.lower_triangle_mask, 0)
|
||||
# [bs, seq_length, seq_length]
|
||||
# the returned shape is [bs, seq_length, seq_length]
|
||||
attention_mask = self.multiply(
|
||||
attention_mask, lower_traiangle)
|
||||
return attention_mask
|
||||
|
@ -889,18 +890,18 @@ class MultiHeadAttention(Cell):
|
|||
query = self.dense1(query_tensor)
|
||||
key = self.dense2(key_tensor)
|
||||
value = self.dense3(value_tensor)
|
||||
# [bs, num_heads, seq_length, size_per_head]
|
||||
# the returned shape is [bs, num_heads, seq_length, size_per_head]
|
||||
query = self.transpose(
|
||||
F.reshape(
|
||||
query,
|
||||
(-1, query_tensor_original_shape[1], self.n_head, self.size_per_head)),
|
||||
(0, 2, 1, 3))
|
||||
# [bs, num_heads, size_per_head, seq_length]
|
||||
# the returned shape is [bs, num_heads, size_per_head, seq_length]
|
||||
key = self.transpose(
|
||||
F.reshape(
|
||||
key, (-1, key_tensor_original_shape[1], self.n_head, self.size_per_head)),
|
||||
(0, 2, 3, 1))
|
||||
# [bs, num_heads, seq_length, size_per_head]
|
||||
# the returned shape is [bs, num_heads, seq_length, size_per_head]
|
||||
value = self.transpose(
|
||||
F.reshape(
|
||||
value,
|
||||
|
@ -949,7 +950,7 @@ class MultiHeadAttention(Cell):
|
|||
|
||||
layer_present = (key_present, value_present)
|
||||
# multi head attention considering attention mask
|
||||
# [bs, seq_length, hidden_size]
|
||||
# the return shape is [bs, seq_length, hidden_size]
|
||||
attention = self._attn(query, key, value, attention_mask)
|
||||
# Output
|
||||
output = self.projection(attention)
|
||||
|
@ -1019,8 +1020,8 @@ class MultiHeadAttention(Cell):
|
|||
ori_dtype = P.DType()(score)
|
||||
score = P.Cast()(score, self.softmax_dtype)
|
||||
|
||||
# for input size of (bs, 1) namely the second graph, the shape of attention_mask matrix should be
|
||||
# (bs, 1, 1, seq_length)
|
||||
# for input size of (bs, 1) namely the second graph,
|
||||
# the shape of attention_mask matrix should be (bs, 1, 1, seq_length)
|
||||
if self.use_past and not self.is_first_iteration:
|
||||
# Calculate the current total token
|
||||
current_index = self.reducesum(F.cast(self.not_equal(self.slice(key, (0, 0, 0, 0),
|
||||
|
@ -1508,7 +1509,7 @@ class TransformerDecoderLayer(Cell):
|
|||
memory_mask=None,
|
||||
init_reset=True, batch_valid_length=None):
|
||||
self._check_input(hidden_stats, decoder_mask, encoder_output, memory_mask, init_reset, batch_valid_length)
|
||||
# [bs, seq_length, embedding_size]
|
||||
# the returned shape is [bs, seq_length, embedding_size]
|
||||
input_x = self.layernorm1(hidden_stats)
|
||||
input_x = F.cast(input_x, self.dtype)
|
||||
|
||||
|
|
|
@ -33,20 +33,19 @@ class EmbeddingLayer(nn.Cell):
|
|||
def __init__(self, config):
|
||||
super(EmbeddingLayer, self).__init__()
|
||||
# Only for the pipeline mode, the embedding needs to be row sliced.
|
||||
copied_parallel_config = copy.deepcopy(config.parallel_config)
|
||||
if copied_parallel_config.pipeline_stage > 1:
|
||||
copied_parallel_config.vocab_emb_dp = False
|
||||
self.word_embedding = VocabEmbedding(vocab_size=config.vocab_size,
|
||||
embedding_size=config.hidden_size,
|
||||
param_init=initializer("normal", [config.vocab_size, config.hidden_size],
|
||||
dtype=config.param_init_type),
|
||||
parallel_config=copied_parallel_config.embedding_dp_mp_config)
|
||||
parallel_config=config.parallel_config.embedding_dp_mp_config)
|
||||
copied_parallel_config = copy.deepcopy(config.parallel_config)
|
||||
copied_parallel_config.vocab_emb_dp = True
|
||||
self.position_embedding = VocabEmbedding(vocab_size=config.seq_length,
|
||||
embedding_size=config.hidden_size,
|
||||
param_init=initializer("normal",
|
||||
[config.seq_length, config.hidden_size],
|
||||
dtype=config.param_init_type),
|
||||
parallel_config=config.parallel_config.embedding_dp_mp_config)
|
||||
parallel_config=copied_parallel_config.embedding_dp_mp_config)
|
||||
self.add = P.Add().shard(
|
||||
((config.parallel_config.data_parallel, 1, 1), (config.parallel_config.data_parallel, 1, 1)))
|
||||
self.dropout = nn.Dropout(1 - config.dropout_rate)
|
||||
|
@ -249,13 +248,14 @@ class PanguAlpha_Model(Cell):
|
|||
param_init_type=config.param_init_type,
|
||||
use_past=config.use_past,
|
||||
parallel_config=config.parallel_config).blocks
|
||||
|
||||
copied_parallel_config = copy.deepcopy(config.parallel_config)
|
||||
copied_parallel_config.vocab_emb_dp = True
|
||||
self.top_query_embedding = VocabEmbedding(vocab_size=config.seq_length,
|
||||
embedding_size=config.hidden_size,
|
||||
param_init=initializer("normal",
|
||||
[config.seq_length, config.hidden_size],
|
||||
dtype=config.param_init_type),
|
||||
parallel_config=config.parallel_config.embedding_dp_mp_config)
|
||||
parallel_config=copied_parallel_config.embedding_dp_mp_config)
|
||||
self.top_query_embedding.pipeline_stage = config.parallel_config.pipeline_stage - 1
|
||||
if config.parallel_config.pipeline_stage > 1:
|
||||
self.top_query_embedding.set_comm_fusion(2)
|
||||
|
|
|
@ -106,6 +106,7 @@ def run_train(args_opt):
|
|||
pipeline_stage=args_opt.stage_num,
|
||||
micro_batch_num=args_opt.micro_size,
|
||||
optimizer_shard=bool(args_opt.optimizer_shard),
|
||||
vocab_emb_dp=bool(args_opt.word_emb_dp),
|
||||
recompute=True)
|
||||
config = PanguAlphaConfig(batch_size=batch_size, num_heads=args_opt.num_heads,
|
||||
hidden_size=args_opt.embedding_size, seq_length=args_opt.seq_length,
|
||||
|
@ -221,6 +222,7 @@ def run_train_pipeline(args_opt):
|
|||
pipeline_stage=args_opt.stage_num,
|
||||
micro_batch_num=args_opt.micro_size,
|
||||
optimizer_shard=bool(args_opt.optimizer_shard),
|
||||
vocab_emb_dp=bool(args_opt.word_emb_dp),
|
||||
recompute=True)
|
||||
config = PanguAlphaConfig(batch_size=batch_size // parallel_config.micro_batch_num,
|
||||
num_heads=args_opt.num_heads, hidden_size=args_opt.embedding_size,
|
||||
|
|
Loading…
Reference in New Issue