add pangu gpu

This commit is contained in:
yao_yf 2021-07-06 12:46:38 +08:00
parent 327ba1962a
commit da67a91c14
5 changed files with 78 additions and 49 deletions

View File

@ -37,33 +37,21 @@ def load_model(args_opt):
r""" r"""
The main function for load model The main function for load model
""" """
device_id = int(os.getenv("DEVICE_ID"))
rank_id_str = os.getenv('RANK_ID', '0')
rank_id = int(
rank_id_str[rank_id_str.rfind('-') +
1:])
print('rank_id:{}'.format(rank_id), "rank_id str:{}".format(rank_id_str))
device_id = int(os.getenv('DEVICE_ID'))
local_rank = rank_id
print('local_rank:{}, device id:{} start to run...'.format(local_rank, device_id), flush=True)
# Set execution mode # Set execution mode
context.set_context(save_graphs=False, context.set_context(save_graphs=False,
mode=context.GRAPH_MODE, mode=context.GRAPH_MODE,
device_target="Ascend", device_target=args_opt.device_target)
device_id=device_id)
context.set_context(variable_memory_max_size="30GB") context.set_context(variable_memory_max_size="30GB")
# Set parallel context # Set parallel context
if args_opt.distribute == "true": if args_opt.distribute == "true":
D.init() D.init()
device_num = D.get_group_size() device_num = D.get_group_size()
rank = D.get_rank() rank = D.get_rank()
print("device_id is {}, rank_id is {}, device_num is {}".format( print("rank_id is {}, device_num is {}".format(rank, device_num))
device_id, rank, device_num))
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
context.set_auto_parallel_context( context.set_auto_parallel_context(
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
gradients_mean=False, gradients_mean=False,
device_num=device_num,
full_batch=True, full_batch=True,
loss_repeated_mean=True, loss_repeated_mean=True,
enable_parallel_optimizer=False, enable_parallel_optimizer=False,
@ -75,7 +63,7 @@ def load_model(args_opt):
else: else:
rank = 0 rank = 0
device_num = 1 device_num = 1
print('local_rank:{}, start to run...'.format(rank), flush=True)
use_past = False use_past = False
if args_opt.export: if args_opt.export:
use_past = True use_past = True

View File

@ -0,0 +1,38 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash run_distributed_train_gpu.sh RANK_SIZE HOSTFILE DATASET MODE"
echo "for example: bash run_distributed_train_gpu.sh 16 hostfile_16p /mass_dataset/train_data/ 2.6B"
echo "It is better to use absolute path."
echo "=============================================================================================================="
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
RANK_SIZE=$1
HOSTFILE=$2
DATASET=$3
MODE=$4
mpirun --allow-run-as-root -x PATH -x LD_LIBRARY_PATH -x PYTHONPATH -x NCCL_DEBUG -x GLOG_v -n $RANK_SIZE --hostfile $HOSTFILE --output-filename log_output --merge-stderr-to-stdout \
python -s ${self_path}/../train.py \
--distribute=true \
--device_num=$RANK_SIZE \
--device_target="GPU" \
--data_url=$DATASET \
--mode=$MODE \
--run_type=train > train_log.txt 2>&1 &

View File

@ -76,6 +76,13 @@ class Dropout(nn.Cell):
def extend_repr(self): def extend_repr(self):
return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype) return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
def shard(self, strategy):
if self.is_ascend:
self.dropout_gen_mask.shard(strategy)
self.dropout_do_mask.shard(strategy)
else:
self.dropout.shard(strategy)
class LayerNorm(nn.Cell): class LayerNorm(nn.Cell):
r""" r"""
A self-defined layer norm operation using reduce sum and reduce mean A self-defined layer norm operation using reduce sum and reduce mean
@ -147,7 +154,7 @@ class Mapping(nn.Cell):
return output return output
class Mapping_output(nn.Cell): class MappingOutput(nn.Cell):
""" """
A mapping function with a 3d input A mapping function with a 3d input
Args: Args:
@ -161,7 +168,7 @@ class Mapping_output(nn.Cell):
output: Tensor, a 3d tensor after projection output: Tensor, a 3d tensor after projection
""" """
def __init__(self, config, input_size, output_size, scale=1.0): def __init__(self, config, input_size, output_size, scale=1.0):
super(Mapping_output, self).__init__() super(MappingOutput, self).__init__()
self.output_size = output_size self.output_size = output_size
self.input_size = input_size self.input_size = input_size
self.weight = Parameter(initializer(Normal(sigma=0.02 * scale), self.weight = Parameter(initializer(Normal(sigma=0.02 * scale),
@ -203,14 +210,13 @@ class Output(nn.Cell):
input_size = config.embedding_size input_size = config.embedding_size
output_size = config.embedding_size * config.expand_ratio output_size = config.embedding_size * config.expand_ratio
# Project to expand_ratio*embedding_size # Project to expand_ratio*embedding_size
self.mapping = Mapping_output(config, input_size, output_size) self.mapping = MappingOutput(config, input_size, output_size)
# Project back to embedding_size # Project back to embedding_size
self.projection = Mapping(config, output_size, input_size, scale) self.projection = Mapping(config, output_size, input_size, scale)
self.activation = nn.GELU() self.activation = nn.GELU()
self.activation.gelu.shard(((config.dp, 1, config.mp),)) self.activation.gelu.shard(((config.dp, 1, config.mp),))
self.dropout = Dropout(1 - config.dropout_rate) self.dropout = Dropout(1 - config.dropout_rate)
self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),)) self.dropout.shard(((config.dp, 1, 1),))
self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
def construct(self, x): def construct(self, x):
# [bs, seq_length, expand_ratio*embedding_size] # [bs, seq_length, expand_ratio*embedding_size]
@ -282,13 +288,9 @@ class Attention(nn.Cell):
self.coeff = Tensor(self.coeff) self.coeff = Tensor(self.coeff)
self.use_past = config.use_past self.use_past = config.use_past
self.dropout = Dropout(1 - config.dropout_rate) self.dropout = Dropout(1 - config.dropout_rate)
self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),)) self.dropout.shard(((config.dp, 1, 1),))
self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
self.prob_dropout = Dropout(1 - config.dropout_rate) self.prob_dropout = Dropout(1 - config.dropout_rate)
self.prob_dropout.dropout_gen_mask.shard( self.prob_dropout.shard(((config.dp, config.mp, 1, 1),))
((config.dp, config.mp, 1, 1),))
self.prob_dropout.dropout_do_mask.shard(
((config.dp, config.mp, 1, 1),))
self.softmax = nn.Softmax() self.softmax = nn.Softmax()
self.softmax.softmax.shard(((config.dp, config.mp, 1),)) self.softmax.softmax.shard(((config.dp, config.mp, 1),))
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),)) self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
@ -631,12 +633,12 @@ class Decoder(nn.Cell):
output = self.add(x, mlp_logit) output = self.add(x, mlp_logit)
return output, layer_present return output, layer_present
class EmbeddingCell(nn.Cell): class Embedding(nn.Cell):
""" """
EmbeddingCell Embedding
""" """
def __init__(self, config): def __init__(self, config):
super(EmbeddingCell, self).__init__() super(Embedding, self).__init__()
self.word_embedding = EmbeddingLookup().set_comm_fusion(1) self.word_embedding = EmbeddingLookup().set_comm_fusion(1)
if config.word_emb_dp: if config.word_emb_dp:
self.word_embedding.gather.shard(((1, 1), (config.dp, 1))) self.word_embedding.gather.shard(((1, 1), (config.dp, 1)))
@ -669,8 +671,7 @@ class EmbeddingCell(nn.Cell):
self.position_embedding.expand.shard(((config.dp, 1),)) self.position_embedding.expand.shard(((config.dp, 1),))
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1))) self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
self.dropout = Dropout(1 - config.dropout_rate) self.dropout = Dropout(1 - config.dropout_rate)
self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),)) self.dropout.shard(((config.dp, 1, 1),))
self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
self.use_past = config.use_past self.use_past = config.use_past
self.is_first_iteration = True self.is_first_iteration = True
@ -686,12 +687,12 @@ class EmbeddingCell(nn.Cell):
return hidden_states return hidden_states
class MaskCell(nn.Cell): class Mask(nn.Cell):
""" """
MaskCell Mask
""" """
def __init__(self, config): def __init__(self, config):
super(MaskCell, self).__init__() super(Mask, self).__init__()
self.dtype = config.compute_dtype self.dtype = config.compute_dtype
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),)) self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
def construct(self, attention_mask): def construct(self, attention_mask):
@ -871,10 +872,10 @@ class PanguAlphaEmbedding(nn.Cell):
""" """
def __init__(self, config): def __init__(self, config):
super(PanguAlphaEmbedding, self).__init__() super(PanguAlphaEmbedding, self).__init__()
self.embedding = EmbeddingCell(config) self.embedding = Embedding(config)
if config.stage_num > 1: if config.stage_num > 1:
self.embedding.pipeline_stage = 0 self.embedding.pipeline_stage = 0
self.mask = MaskCell(config) self.mask = Mask(config)
def construct(self, input_ids, input_mask, table, input_position, attention_mask, valid_index=None): def construct(self, input_ids, input_mask, table, input_position, attention_mask, valid_index=None):
""" """

View File

@ -397,6 +397,11 @@ def get_args(inference=False):
default="2.6B", default="2.6B",
choices=["200B", "13B", "2.6B", "self_define"], choices=["200B", "13B", "2.6B", "self_define"],
help="The scale of the model parameters") help="The scale of the model parameters")
parser.add_argument("--device_target",
type=str,
default="Ascend",
choices=["Ascend", "GPU"],
help="The running device")
parser.add_argument("--strategy_load_ckpt_path", parser.add_argument("--strategy_load_ckpt_path",
type=str, type=str,
default="", default="",

View File

@ -44,12 +44,13 @@ class LossCallBack(Callback):
If the loss in NAN or INF terminating training. If the loss in NAN or INF terminating training.
""" """
def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0): def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0, micro_size=1):
super(LossCallBack, self).__init__() super(LossCallBack, self).__init__()
self._dataset_size = dataset_size self._dataset_size = dataset_size
self.local_rank = local_rank self.local_rank = local_rank
self.has_trained_epoch = has_trained_epoch self.has_trained_epoch = has_trained_epoch
self.has_trained_step = has_trained_step self.has_trained_step = has_trained_step
self.micro_size = micro_size
print("load has trained epoch :{} and step: {}".format(has_trained_epoch, has_trained_step), flush=True) print("load has trained epoch :{} and step: {}".format(has_trained_epoch, has_trained_step), flush=True)
def step_end(self, run_context): def step_end(self, run_context):
@ -63,9 +64,10 @@ class LossCallBack(Callback):
if percent == 0: if percent == 0:
epoch_num -= 1 epoch_num -= 1
date = time.asctime(time.localtime(time.time())) date = time.asctime(time.localtime(time.time()))
loss_value = cb_params.net_outputs[0].asnumpy() / self.micro_size
print("time: {} local_rank: {}, epoch: {}, step: {}, output is {}, overflow is {}, scale is {}". print("time: {} local_rank: {}, epoch: {}, step: {}, output is {}, overflow is {}, scale is {}".
format(date, int(self.local_rank), int(epoch_num) + int(self.has_trained_epoch), format(date, int(self.local_rank), int(epoch_num) + int(self.has_trained_epoch),
cb_params.cur_step_num + int(self.has_trained_step), cb_params.net_outputs[0].asnumpy(), cb_params.cur_step_num + int(self.has_trained_step), loss_value,
cb_params.net_outputs[1].asnumpy(), cb_params.net_outputs[2].asnumpy())) cb_params.net_outputs[1].asnumpy(), cb_params.net_outputs[2].asnumpy()))
@ -78,25 +80,21 @@ def run_train(args_opt):
r""" r"""
The main training process. The main training process.
""" """
device_id = int(os.getenv('DEVICE_ID'))
# Set execution mode # Set execution mode
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", device_target=args_opt.device_target)
device_id=device_id)
context.set_context(variable_memory_max_size="30GB") context.set_context(variable_memory_max_size="30GB")
# Set parallel context # Set parallel context
if args_opt.distribute == "true": if args_opt.distribute == "true":
D.init() D.init()
device_num = D.get_group_size() device_num = D.get_group_size()
rank = D.get_rank() rank = D.get_rank()
print("device_id is {}, rank_id is {}, device_num is {}".format( print("rank_id is {}, device_num is {}".format(rank, device_num))
device_id, rank, device_num))
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
context.set_auto_parallel_context( context.set_auto_parallel_context(
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
gradients_mean=False, gradients_mean=False,
device_num=device_num,
full_batch=bool(args_opt.full_batch), full_batch=bool(args_opt.full_batch),
enable_parallel_optimizer=bool(args_opt.optimizer_shard)) enable_parallel_optimizer=bool(args_opt.optimizer_shard))
set_algo_parameters(elementwise_op_strategy_follow=True) set_algo_parameters(elementwise_op_strategy_follow=True)
@ -105,7 +103,7 @@ def run_train(args_opt):
else: else:
rank = 0 rank = 0
device_num = 1 device_num = 1
context.set_context(save_graphs=False, save_graphs_path="./graphs_of_device_id_" + str(rank))
# copy data from the cloud to the /cache/Data # copy data from the cloud to the /cache/Data
cache_url = '/cache/Data/' cache_url = '/cache/Data/'
if args_opt.offline: if args_opt.offline:
@ -194,18 +192,17 @@ def run_train_pipeline(args_opt):
r""" r"""
The main training process in pipeline. The main training process in pipeline.
""" """
device_id = int(os.getenv("DEVICE_ID")) context.set_context(save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(save_graphs=False, mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
context.set_context(variable_memory_max_size="31GB") context.set_context(variable_memory_max_size="31GB")
if args_opt.distribute == "true": if args_opt.distribute == "true":
D.init() D.init()
device_num = D.get_group_size() device_num = D.get_group_size()
rank_id = D.get_rank() rank_id = D.get_rank()
print("rank_id is {}, device_num is {}".format(rank_id, device_num))
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
context.set_auto_parallel_context( context.set_auto_parallel_context(
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
gradients_mean=False, gradients_mean=False,
device_num=device_num,
full_batch=bool(args_opt.full_batch), full_batch=bool(args_opt.full_batch),
loss_repeated_mean=True, loss_repeated_mean=True,
enable_parallel_optimizer=bool(args_opt.optimizer_shard), enable_parallel_optimizer=bool(args_opt.optimizer_shard),
@ -281,7 +278,7 @@ def run_train_pipeline(args_opt):
step_per_epoch = ds.get_dataset_size() step_per_epoch = ds.get_dataset_size()
callback_size = args_opt.sink_size callback_size = args_opt.sink_size
actual_epoch_num = int(epoch_num * step_per_epoch / callback_size) actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, config.stage_num)] callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, config.stage_num, config.micro_size)]
loss_scale_value = math.pow(2, 32) loss_scale_value = math.pow(2, 32)
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000) update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell( pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(