forked from mindspore-Ecosystem/mindspore
add pangu gpu
This commit is contained in:
parent
327ba1962a
commit
da67a91c14
|
@ -37,33 +37,21 @@ def load_model(args_opt):
|
||||||
r"""
|
r"""
|
||||||
The main function for load model
|
The main function for load model
|
||||||
"""
|
"""
|
||||||
device_id = int(os.getenv("DEVICE_ID"))
|
|
||||||
rank_id_str = os.getenv('RANK_ID', '0')
|
|
||||||
rank_id = int(
|
|
||||||
rank_id_str[rank_id_str.rfind('-') +
|
|
||||||
1:])
|
|
||||||
print('rank_id:{}'.format(rank_id), "rank_id str:{}".format(rank_id_str))
|
|
||||||
device_id = int(os.getenv('DEVICE_ID'))
|
|
||||||
local_rank = rank_id
|
|
||||||
print('local_rank:{}, device id:{} start to run...'.format(local_rank, device_id), flush=True)
|
|
||||||
# Set execution mode
|
# Set execution mode
|
||||||
context.set_context(save_graphs=False,
|
context.set_context(save_graphs=False,
|
||||||
mode=context.GRAPH_MODE,
|
mode=context.GRAPH_MODE,
|
||||||
device_target="Ascend",
|
device_target=args_opt.device_target)
|
||||||
device_id=device_id)
|
|
||||||
context.set_context(variable_memory_max_size="30GB")
|
context.set_context(variable_memory_max_size="30GB")
|
||||||
# Set parallel context
|
# Set parallel context
|
||||||
if args_opt.distribute == "true":
|
if args_opt.distribute == "true":
|
||||||
D.init()
|
D.init()
|
||||||
device_num = D.get_group_size()
|
device_num = D.get_group_size()
|
||||||
rank = D.get_rank()
|
rank = D.get_rank()
|
||||||
print("device_id is {}, rank_id is {}, device_num is {}".format(
|
print("rank_id is {}, device_num is {}".format(rank, device_num))
|
||||||
device_id, rank, device_num))
|
|
||||||
context.reset_auto_parallel_context()
|
context.reset_auto_parallel_context()
|
||||||
context.set_auto_parallel_context(
|
context.set_auto_parallel_context(
|
||||||
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
|
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
|
||||||
gradients_mean=False,
|
gradients_mean=False,
|
||||||
device_num=device_num,
|
|
||||||
full_batch=True,
|
full_batch=True,
|
||||||
loss_repeated_mean=True,
|
loss_repeated_mean=True,
|
||||||
enable_parallel_optimizer=False,
|
enable_parallel_optimizer=False,
|
||||||
|
@ -75,7 +63,7 @@ def load_model(args_opt):
|
||||||
else:
|
else:
|
||||||
rank = 0
|
rank = 0
|
||||||
device_num = 1
|
device_num = 1
|
||||||
|
print('local_rank:{}, start to run...'.format(rank), flush=True)
|
||||||
use_past = False
|
use_past = False
|
||||||
if args_opt.export:
|
if args_opt.export:
|
||||||
use_past = True
|
use_past = True
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
echo "Please run the script as: "
|
||||||
|
echo "bash run_distributed_train_gpu.sh RANK_SIZE HOSTFILE DATASET MODE"
|
||||||
|
echo "for example: bash run_distributed_train_gpu.sh 16 hostfile_16p /mass_dataset/train_data/ 2.6B"
|
||||||
|
echo "It is better to use absolute path."
|
||||||
|
echo "=============================================================================================================="
|
||||||
|
|
||||||
|
script_self=$(readlink -f "$0")
|
||||||
|
self_path=$(dirname "${script_self}")
|
||||||
|
RANK_SIZE=$1
|
||||||
|
HOSTFILE=$2
|
||||||
|
DATASET=$3
|
||||||
|
MODE=$4
|
||||||
|
|
||||||
|
mpirun --allow-run-as-root -x PATH -x LD_LIBRARY_PATH -x PYTHONPATH -x NCCL_DEBUG -x GLOG_v -n $RANK_SIZE --hostfile $HOSTFILE --output-filename log_output --merge-stderr-to-stdout \
|
||||||
|
python -s ${self_path}/../train.py \
|
||||||
|
--distribute=true \
|
||||||
|
--device_num=$RANK_SIZE \
|
||||||
|
--device_target="GPU" \
|
||||||
|
--data_url=$DATASET \
|
||||||
|
--mode=$MODE \
|
||||||
|
--run_type=train > train_log.txt 2>&1 &
|
|
@ -76,6 +76,13 @@ class Dropout(nn.Cell):
|
||||||
def extend_repr(self):
|
def extend_repr(self):
|
||||||
return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
|
return 'keep_prob={}, dtype={}'.format(self.keep_prob, self.dtype)
|
||||||
|
|
||||||
|
def shard(self, strategy):
|
||||||
|
if self.is_ascend:
|
||||||
|
self.dropout_gen_mask.shard(strategy)
|
||||||
|
self.dropout_do_mask.shard(strategy)
|
||||||
|
else:
|
||||||
|
self.dropout.shard(strategy)
|
||||||
|
|
||||||
class LayerNorm(nn.Cell):
|
class LayerNorm(nn.Cell):
|
||||||
r"""
|
r"""
|
||||||
A self-defined layer norm operation using reduce sum and reduce mean
|
A self-defined layer norm operation using reduce sum and reduce mean
|
||||||
|
@ -147,7 +154,7 @@ class Mapping(nn.Cell):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
class Mapping_output(nn.Cell):
|
class MappingOutput(nn.Cell):
|
||||||
"""
|
"""
|
||||||
A mapping function with a 3d input
|
A mapping function with a 3d input
|
||||||
Args:
|
Args:
|
||||||
|
@ -161,7 +168,7 @@ class Mapping_output(nn.Cell):
|
||||||
output: Tensor, a 3d tensor after projection
|
output: Tensor, a 3d tensor after projection
|
||||||
"""
|
"""
|
||||||
def __init__(self, config, input_size, output_size, scale=1.0):
|
def __init__(self, config, input_size, output_size, scale=1.0):
|
||||||
super(Mapping_output, self).__init__()
|
super(MappingOutput, self).__init__()
|
||||||
self.output_size = output_size
|
self.output_size = output_size
|
||||||
self.input_size = input_size
|
self.input_size = input_size
|
||||||
self.weight = Parameter(initializer(Normal(sigma=0.02 * scale),
|
self.weight = Parameter(initializer(Normal(sigma=0.02 * scale),
|
||||||
|
@ -203,14 +210,13 @@ class Output(nn.Cell):
|
||||||
input_size = config.embedding_size
|
input_size = config.embedding_size
|
||||||
output_size = config.embedding_size * config.expand_ratio
|
output_size = config.embedding_size * config.expand_ratio
|
||||||
# Project to expand_ratio*embedding_size
|
# Project to expand_ratio*embedding_size
|
||||||
self.mapping = Mapping_output(config, input_size, output_size)
|
self.mapping = MappingOutput(config, input_size, output_size)
|
||||||
# Project back to embedding_size
|
# Project back to embedding_size
|
||||||
self.projection = Mapping(config, output_size, input_size, scale)
|
self.projection = Mapping(config, output_size, input_size, scale)
|
||||||
self.activation = nn.GELU()
|
self.activation = nn.GELU()
|
||||||
self.activation.gelu.shard(((config.dp, 1, config.mp),))
|
self.activation.gelu.shard(((config.dp, 1, config.mp),))
|
||||||
self.dropout = Dropout(1 - config.dropout_rate)
|
self.dropout = Dropout(1 - config.dropout_rate)
|
||||||
self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
|
self.dropout.shard(((config.dp, 1, 1),))
|
||||||
self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
|
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
# [bs, seq_length, expand_ratio*embedding_size]
|
# [bs, seq_length, expand_ratio*embedding_size]
|
||||||
|
@ -282,13 +288,9 @@ class Attention(nn.Cell):
|
||||||
self.coeff = Tensor(self.coeff)
|
self.coeff = Tensor(self.coeff)
|
||||||
self.use_past = config.use_past
|
self.use_past = config.use_past
|
||||||
self.dropout = Dropout(1 - config.dropout_rate)
|
self.dropout = Dropout(1 - config.dropout_rate)
|
||||||
self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
|
self.dropout.shard(((config.dp, 1, 1),))
|
||||||
self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
|
|
||||||
self.prob_dropout = Dropout(1 - config.dropout_rate)
|
self.prob_dropout = Dropout(1 - config.dropout_rate)
|
||||||
self.prob_dropout.dropout_gen_mask.shard(
|
self.prob_dropout.shard(((config.dp, config.mp, 1, 1),))
|
||||||
((config.dp, config.mp, 1, 1),))
|
|
||||||
self.prob_dropout.dropout_do_mask.shard(
|
|
||||||
((config.dp, config.mp, 1, 1),))
|
|
||||||
self.softmax = nn.Softmax()
|
self.softmax = nn.Softmax()
|
||||||
self.softmax.softmax.shard(((config.dp, config.mp, 1),))
|
self.softmax.softmax.shard(((config.dp, config.mp, 1),))
|
||||||
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
|
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
|
||||||
|
@ -631,12 +633,12 @@ class Decoder(nn.Cell):
|
||||||
output = self.add(x, mlp_logit)
|
output = self.add(x, mlp_logit)
|
||||||
return output, layer_present
|
return output, layer_present
|
||||||
|
|
||||||
class EmbeddingCell(nn.Cell):
|
class Embedding(nn.Cell):
|
||||||
"""
|
"""
|
||||||
EmbeddingCell
|
Embedding
|
||||||
"""
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(EmbeddingCell, self).__init__()
|
super(Embedding, self).__init__()
|
||||||
self.word_embedding = EmbeddingLookup().set_comm_fusion(1)
|
self.word_embedding = EmbeddingLookup().set_comm_fusion(1)
|
||||||
if config.word_emb_dp:
|
if config.word_emb_dp:
|
||||||
self.word_embedding.gather.shard(((1, 1), (config.dp, 1)))
|
self.word_embedding.gather.shard(((1, 1), (config.dp, 1)))
|
||||||
|
@ -669,8 +671,7 @@ class EmbeddingCell(nn.Cell):
|
||||||
self.position_embedding.expand.shard(((config.dp, 1),))
|
self.position_embedding.expand.shard(((config.dp, 1),))
|
||||||
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
|
self.add = P.TensorAdd().shard(((config.dp, 1, 1), (config.dp, 1, 1)))
|
||||||
self.dropout = Dropout(1 - config.dropout_rate)
|
self.dropout = Dropout(1 - config.dropout_rate)
|
||||||
self.dropout.dropout_gen_mask.shard(((config.dp, 1, 1),))
|
self.dropout.shard(((config.dp, 1, 1),))
|
||||||
self.dropout.dropout_do_mask.shard(((config.dp, 1, 1),))
|
|
||||||
self.use_past = config.use_past
|
self.use_past = config.use_past
|
||||||
self.is_first_iteration = True
|
self.is_first_iteration = True
|
||||||
|
|
||||||
|
@ -686,12 +687,12 @@ class EmbeddingCell(nn.Cell):
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
class MaskCell(nn.Cell):
|
class Mask(nn.Cell):
|
||||||
"""
|
"""
|
||||||
MaskCell
|
Mask
|
||||||
"""
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(MaskCell, self).__init__()
|
super(Mask, self).__init__()
|
||||||
self.dtype = config.compute_dtype
|
self.dtype = config.compute_dtype
|
||||||
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
|
self.expand_dims = P.ExpandDims().shard(((config.dp, 1, 1),))
|
||||||
def construct(self, attention_mask):
|
def construct(self, attention_mask):
|
||||||
|
@ -871,10 +872,10 @@ class PanguAlphaEmbedding(nn.Cell):
|
||||||
"""
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(PanguAlphaEmbedding, self).__init__()
|
super(PanguAlphaEmbedding, self).__init__()
|
||||||
self.embedding = EmbeddingCell(config)
|
self.embedding = Embedding(config)
|
||||||
if config.stage_num > 1:
|
if config.stage_num > 1:
|
||||||
self.embedding.pipeline_stage = 0
|
self.embedding.pipeline_stage = 0
|
||||||
self.mask = MaskCell(config)
|
self.mask = Mask(config)
|
||||||
|
|
||||||
def construct(self, input_ids, input_mask, table, input_position, attention_mask, valid_index=None):
|
def construct(self, input_ids, input_mask, table, input_position, attention_mask, valid_index=None):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -397,6 +397,11 @@ def get_args(inference=False):
|
||||||
default="2.6B",
|
default="2.6B",
|
||||||
choices=["200B", "13B", "2.6B", "self_define"],
|
choices=["200B", "13B", "2.6B", "self_define"],
|
||||||
help="The scale of the model parameters")
|
help="The scale of the model parameters")
|
||||||
|
parser.add_argument("--device_target",
|
||||||
|
type=str,
|
||||||
|
default="Ascend",
|
||||||
|
choices=["Ascend", "GPU"],
|
||||||
|
help="The running device")
|
||||||
parser.add_argument("--strategy_load_ckpt_path",
|
parser.add_argument("--strategy_load_ckpt_path",
|
||||||
type=str,
|
type=str,
|
||||||
default="",
|
default="",
|
||||||
|
|
|
@ -44,12 +44,13 @@ class LossCallBack(Callback):
|
||||||
If the loss in NAN or INF terminating training.
|
If the loss in NAN or INF terminating training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0):
|
def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0, micro_size=1):
|
||||||
super(LossCallBack, self).__init__()
|
super(LossCallBack, self).__init__()
|
||||||
self._dataset_size = dataset_size
|
self._dataset_size = dataset_size
|
||||||
self.local_rank = local_rank
|
self.local_rank = local_rank
|
||||||
self.has_trained_epoch = has_trained_epoch
|
self.has_trained_epoch = has_trained_epoch
|
||||||
self.has_trained_step = has_trained_step
|
self.has_trained_step = has_trained_step
|
||||||
|
self.micro_size = micro_size
|
||||||
print("load has trained epoch :{} and step: {}".format(has_trained_epoch, has_trained_step), flush=True)
|
print("load has trained epoch :{} and step: {}".format(has_trained_epoch, has_trained_step), flush=True)
|
||||||
|
|
||||||
def step_end(self, run_context):
|
def step_end(self, run_context):
|
||||||
|
@ -63,9 +64,10 @@ class LossCallBack(Callback):
|
||||||
if percent == 0:
|
if percent == 0:
|
||||||
epoch_num -= 1
|
epoch_num -= 1
|
||||||
date = time.asctime(time.localtime(time.time()))
|
date = time.asctime(time.localtime(time.time()))
|
||||||
|
loss_value = cb_params.net_outputs[0].asnumpy() / self.micro_size
|
||||||
print("time: {} local_rank: {}, epoch: {}, step: {}, output is {}, overflow is {}, scale is {}".
|
print("time: {} local_rank: {}, epoch: {}, step: {}, output is {}, overflow is {}, scale is {}".
|
||||||
format(date, int(self.local_rank), int(epoch_num) + int(self.has_trained_epoch),
|
format(date, int(self.local_rank), int(epoch_num) + int(self.has_trained_epoch),
|
||||||
cb_params.cur_step_num + int(self.has_trained_step), cb_params.net_outputs[0].asnumpy(),
|
cb_params.cur_step_num + int(self.has_trained_step), loss_value,
|
||||||
cb_params.net_outputs[1].asnumpy(), cb_params.net_outputs[2].asnumpy()))
|
cb_params.net_outputs[1].asnumpy(), cb_params.net_outputs[2].asnumpy()))
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,25 +80,21 @@ def run_train(args_opt):
|
||||||
r"""
|
r"""
|
||||||
The main training process.
|
The main training process.
|
||||||
"""
|
"""
|
||||||
device_id = int(os.getenv('DEVICE_ID'))
|
|
||||||
# Set execution mode
|
# Set execution mode
|
||||||
context.set_context(mode=context.GRAPH_MODE,
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
device_target="Ascend",
|
device_target=args_opt.device_target)
|
||||||
device_id=device_id)
|
|
||||||
context.set_context(variable_memory_max_size="30GB")
|
context.set_context(variable_memory_max_size="30GB")
|
||||||
# Set parallel context
|
# Set parallel context
|
||||||
if args_opt.distribute == "true":
|
if args_opt.distribute == "true":
|
||||||
D.init()
|
D.init()
|
||||||
device_num = D.get_group_size()
|
device_num = D.get_group_size()
|
||||||
rank = D.get_rank()
|
rank = D.get_rank()
|
||||||
print("device_id is {}, rank_id is {}, device_num is {}".format(
|
print("rank_id is {}, device_num is {}".format(rank, device_num))
|
||||||
device_id, rank, device_num))
|
|
||||||
|
|
||||||
context.reset_auto_parallel_context()
|
context.reset_auto_parallel_context()
|
||||||
context.set_auto_parallel_context(
|
context.set_auto_parallel_context(
|
||||||
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
|
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
|
||||||
gradients_mean=False,
|
gradients_mean=False,
|
||||||
device_num=device_num,
|
|
||||||
full_batch=bool(args_opt.full_batch),
|
full_batch=bool(args_opt.full_batch),
|
||||||
enable_parallel_optimizer=bool(args_opt.optimizer_shard))
|
enable_parallel_optimizer=bool(args_opt.optimizer_shard))
|
||||||
set_algo_parameters(elementwise_op_strategy_follow=True)
|
set_algo_parameters(elementwise_op_strategy_follow=True)
|
||||||
|
@ -105,7 +103,7 @@ def run_train(args_opt):
|
||||||
else:
|
else:
|
||||||
rank = 0
|
rank = 0
|
||||||
device_num = 1
|
device_num = 1
|
||||||
|
context.set_context(save_graphs=False, save_graphs_path="./graphs_of_device_id_" + str(rank))
|
||||||
# copy data from the cloud to the /cache/Data
|
# copy data from the cloud to the /cache/Data
|
||||||
cache_url = '/cache/Data/'
|
cache_url = '/cache/Data/'
|
||||||
if args_opt.offline:
|
if args_opt.offline:
|
||||||
|
@ -194,18 +192,17 @@ def run_train_pipeline(args_opt):
|
||||||
r"""
|
r"""
|
||||||
The main training process in pipeline.
|
The main training process in pipeline.
|
||||||
"""
|
"""
|
||||||
device_id = int(os.getenv("DEVICE_ID"))
|
context.set_context(save_graphs=False, mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||||
context.set_context(save_graphs=False, mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
|
|
||||||
context.set_context(variable_memory_max_size="31GB")
|
context.set_context(variable_memory_max_size="31GB")
|
||||||
if args_opt.distribute == "true":
|
if args_opt.distribute == "true":
|
||||||
D.init()
|
D.init()
|
||||||
device_num = D.get_group_size()
|
device_num = D.get_group_size()
|
||||||
rank_id = D.get_rank()
|
rank_id = D.get_rank()
|
||||||
|
print("rank_id is {}, device_num is {}".format(rank_id, device_num))
|
||||||
context.reset_auto_parallel_context()
|
context.reset_auto_parallel_context()
|
||||||
context.set_auto_parallel_context(
|
context.set_auto_parallel_context(
|
||||||
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
|
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
|
||||||
gradients_mean=False,
|
gradients_mean=False,
|
||||||
device_num=device_num,
|
|
||||||
full_batch=bool(args_opt.full_batch),
|
full_batch=bool(args_opt.full_batch),
|
||||||
loss_repeated_mean=True,
|
loss_repeated_mean=True,
|
||||||
enable_parallel_optimizer=bool(args_opt.optimizer_shard),
|
enable_parallel_optimizer=bool(args_opt.optimizer_shard),
|
||||||
|
@ -281,7 +278,7 @@ def run_train_pipeline(args_opt):
|
||||||
step_per_epoch = ds.get_dataset_size()
|
step_per_epoch = ds.get_dataset_size()
|
||||||
callback_size = args_opt.sink_size
|
callback_size = args_opt.sink_size
|
||||||
actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
|
actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
|
||||||
callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, config.stage_num)]
|
callback = [TimeMonitor(callback_size), LossCallBack(callback_size, rank_id, config.stage_num, config.micro_size)]
|
||||||
loss_scale_value = math.pow(2, 32)
|
loss_scale_value = math.pow(2, 32)
|
||||||
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
|
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
|
||||||
pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
|
pangu_alpha_with_grads = PanguAlphaTrainPipelineWithLossScaleCell(
|
||||||
|
|
Loading…
Reference in New Issue