diff --git a/mindspore/nn/optim/thor.py b/mindspore/nn/optim/thor.py
index 6e34500c46c..2e22ea0d0d6 100644
--- a/mindspore/nn/optim/thor.py
+++ b/mindspore/nn/optim/thor.py
@@ -755,13 +755,13 @@ class ThorAscend(Optimizer):
         if self.conv_layer_count > 0:
             auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2")
             auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4")
-            self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=3)
-            self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=5)
+            self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=2)
+            self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=4)
             auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
             auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
-            self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=9)
-            self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=17)
+            self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
+            self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)
 
     def _process_matrix_init_and_weight_idx_map(self, net):
         """for Ascend, process matrix init shape, and get weight idx map"""
diff --git a/model_zoo/official/nlp/bert/pretrain_config.yaml b/model_zoo/official/nlp/bert/pretrain_config.yaml
index 8bd072292c0..01aa73cd379 100644
--- a/model_zoo/official/nlp/bert/pretrain_config.yaml
+++ b/model_zoo/official/nlp/bert/pretrain_config.yaml
@@ -26,7 +26,7 @@ accumulation_steps: 1
 allreduce_post_accumulation: 'true'
 save_checkpoint_path: ''
 load_checkpoint_path: ''
-save_checkpoint_steps: 1000
+save_checkpoint_steps: 10000
 train_steps: -1
 save_checkpoint_num: 1
 data_dir: ''
diff --git a/model_zoo/official/nlp/bert/pretrain_config_Ascend_Thor.yaml b/model_zoo/official/nlp/bert/pretrain_config_Ascend_Thor.yaml
new file mode 100644
index 00000000000..93670c29e50
--- /dev/null
+++ b/model_zoo/official/nlp/bert/pretrain_config_Ascend_Thor.yaml
@@ -0,0 +1,194 @@
+# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+# Url for modelarts
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+# Path for local
+data_path: "/cache/data"
+output_path: "/cache/train"
+load_path: "/cache/checkpoint_path"
+device_target: "Ascend"
+enable_profiling: False
+
+# ==============================================================================
+description: 'run_pretrain'
+distribute: 'false'
+epoch_size: 40
+device_id: 0
+device_num: 1
+enable_save_ckpt: 'true'
+enable_lossscale: 'false'
+do_shuffle: 'true'
+enable_data_sink: 'true'
+data_sink_steps: 100
+accumulation_steps: 1
+allreduce_post_accumulation: 'true'
+save_checkpoint_path: ''
+load_checkpoint_path: ''
+save_checkpoint_steps: 500
+train_steps: 2500
+save_checkpoint_num: 5
+data_dir: ''
+schema_dir: ''
+
+# ==============================================================================
+# pretrain related
+batch_size: 20
+# Available: [base, nezha, large, large_acc]
+bert_network: 'large_acc'
+loss_scale_value: 65536
+scale_factor: 2
+scale_window: 1000
+optimizer: 'Thor'
+enable_global_norm: False
+# pretrain_eval related
+data_file: ""
+schema_file: null
+finetune_ckpt: ""
+# optimizer related
+AdamWeightDecay:
+    learning_rate: 0.00003  # 3e-5
+    end_learning_rate: 0.0
+    power: 5.0
+    weight_decay: 0.00001  # 1e-5
+    decay_filter: ['layernorm', 'bias']
+    eps: 0.000001  # 1e-6
+    warmup_steps: 10000
+
+Lamb:
+    learning_rate: 0.0003  # 3e-4
+    end_learning_rate: 0.0
+    power: 2.0
+    warmup_steps: 10000
+    weight_decay: 0.01
+    decay_filter: ['layernorm', 'bias']
+    eps: 0.00000001  # 1e-8
+
+Momentum:
+    learning_rate: 0.00002  # 2e-5
+    momentum: 0.9
+
+Thor:
+    lr_max: 0.006464
+    lr_min: 0.000001  # 1e-6
+    lr_power: 2.0
+    lr_total_steps: 30000
+    damping_max: 0.007035
+    damping_min: 0.000001  # 1e-6
+    damping_power: 4.0
+    damping_total_steps: 30000
+    momentum: 0.9
+    weight_decay: 0.00001  # 1e-5
+    loss_scale: 1024.0
+    frequency: 100
+# ==============================================================================
+# base
+base_batch_size: 256
+base_net_cfg:
+    seq_length: 128
+    vocab_size: 21128
+    hidden_size: 768
+    num_hidden_layers: 12
+    num_attention_heads: 12
+    intermediate_size: 3072
+    hidden_act: "gelu"
+    hidden_dropout_prob: 0.1
+    attention_probs_dropout_prob: 0.1
+    max_position_embeddings: 512
+    type_vocab_size: 2
+    initializer_range: 0.02
+    use_relative_positions: False
+    dtype: mstype.float32
+    compute_type: mstype.float16
+# nezha
+nezha_batch_size: 96
+nezha_net_cfg:
+    seq_length: 128
+    vocab_size: 21128
+    hidden_size: 1024
+    num_hidden_layers: 24
+    num_attention_heads: 16
+    intermediate_size: 4096
+    hidden_act: "gelu"
+    hidden_dropout_prob: 0.1
+    attention_probs_dropout_prob: 0.1
+    max_position_embeddings: 512
+    type_vocab_size: 2
+    initializer_range: 0.02
+    use_relative_positions: True
+    dtype: mstype.float32
+    compute_type: mstype.float16
+# large
+large_batch_size: 20
+large_net_cfg:
+    seq_length: 512
+    vocab_size: 30522
+    hidden_size: 1024
+    num_hidden_layers: 24
+    num_attention_heads: 16
+    intermediate_size: 4096
+    hidden_act: "gelu"
+    hidden_dropout_prob: 0.1
+    attention_probs_dropout_prob: 0.1
+    max_position_embeddings: 512
+    type_vocab_size: 2
+    initializer_range: 0.02
+    use_relative_positions: False
+    dtype: mstype.float32
+    compute_type: mstype.float16
+# Accelerated large network which is only supported in Ascend yet.
+large_acc_batch_size: 20
+large_acc_net_cfg:
+    seq_length: 512
+    vocab_size: 30522
+    hidden_size: 1024
+    num_hidden_layers: 24
+    num_attention_heads: 16
+    intermediate_size: 4096
+    hidden_act: "fast_gelu"
+    hidden_dropout_prob: 0.1
+    attention_probs_dropout_prob: 0.1
+    max_position_embeddings: 512
+    type_vocab_size: 2
+    initializer_range: 0.02
+    use_relative_positions: False
+    dtype: mstype.float32
+    compute_type: mstype.float16
+
+
+---
+# Help description for each configuration
+enable_modelarts: "Whether training on modelarts, default: False"
+data_url: "Url for modelarts"
+train_url: "Url for modelarts"
+data_path: "The location of the input data."
+output_path: "The location of the output file."
+device_target: "Running platform, choose from Ascend or CPU, and default is Ascend."
+enable_profiling: 'Whether enable profiling while training, default: False'
+
+distribute: "Run distribute, default is 'false'."
+epoch_size: "Epoch size, default is 1."
+enable_save_ckpt: "Enable save checkpoint, default is true."
+enable_lossscale: "Use lossscale or not, default is not."
+do_shuffle: "Enable shuffle for dataset, default is true."
+enable_data_sink: "Enable data sink, default is true."
+data_sink_steps: "Sink steps for each epoch, default is 1."
+accumulation_steps: "Accumulating gradients N times before weight update, default is 1."
+allreduce_post_accumulation: "Whether to allreduce after accumulation of N steps or after each step, default is true."
+save_checkpoint_path: "Save checkpoint path"
+load_checkpoint_path: "Load checkpoint file path"
+save_checkpoint_steps: "Save checkpoint steps, default is 1000"
+train_steps: "Training Steps, default is -1, meaning run all steps according to epoch number."
+save_checkpoint_num: "Save checkpoint numbers, default is 1."
+data_dir: "Data path, it is better to use absolute path"
+schema_dir: "Schema path, it is better to use absolute path"
+---
+# choices
+device_target: ['Ascend', 'GPU']
+distribute: ["true", "false"]
+enable_save_ckpt: ["true", "false"]
+enable_lossscale: ["true", "false"]
+do_shuffle: ["true", "false"]
+enable_data_sink: ["true", "false"]
+allreduce_post_accumulation: ["true", "false"]
diff --git a/model_zoo/official/nlp/bert/pretrain_eval.py b/model_zoo/official/nlp/bert/pretrain_eval.py
index 2f516ead4d8..11b5368c3b6 100644
--- a/model_zoo/official/nlp/bert/pretrain_eval.py
+++ b/model_zoo/official/nlp/bert/pretrain_eval.py
@@ -150,10 +150,11 @@ def MLM_eval():
     res = net.eval(dataset, dataset_sink_mode=False)
     print("==============================================================")
     for _, v in res.items():
-        print("Accuracy is: ")
-        print(v)
+        print("Accuracy is: ", v)
     print("==============================================================")
 
 
 if __name__ == "__main__":
+    DEVICE_ID = 0
+    os.environ['DEVICE_ID'] = str(DEVICE_ID)
     MLM_eval()
diff --git a/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/get_distribute_pretrain_cmd.py b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/get_distribute_pretrain_cmd.py
index c20b98787c8..db85549506a 100644
--- a/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/get_distribute_pretrain_cmd.py
+++ b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/get_distribute_pretrain_cmd.py
@@ -47,6 +47,8 @@ def parse_args():
     parser.add_argument("--hccl_time_out", type=int, default=120,
                         help="Seconds to determine the hccl time out,"
                              "default: 120, which is the same as hccl default config")
+    parser.add_argument("--hccn_config_file", type=str, default="/etc/hccn.conf",
+                        help="Path of the hccn.conf file.")
     args = parser.parse_args()
 
     return args
@@ -128,7 +130,7 @@ def distribute_pretrain():
     # get device_ips
     device_ips = {}
     physic_logic_ids = {}
-    with open('/etc/hccn.conf', 'r') as fin:
+    with open(args.hccn_config_file, 'r') as fin:
         for hccn_item in fin.readlines():
             if hccn_item.strip().startswith('address_'):
                 device_id, device_ip = hccn_item.split('=')
@@ -136,7 +138,7 @@
             device_ips[device_id] = device_ip.strip()
 
     if not device_ips:
-        raise ValueError("There is no address in /etc/hccn.conf")
+        raise ValueError("There is no address in hccn.conf file.")
 
     for logic_id, device_id in enumerate(sorted(device_ips.keys())):
         physic_logic_ids[device_id] = logic_id
diff --git a/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini
index 5489c9ca75c..4e90f7d31a4 100644
--- a/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini
+++ b/model_zoo/official/nlp/bert/scripts/ascend_distributed_launcher/hyper_parameter_config.ini
@@ -2,12 +2,11 @@
 distribute=true
 epoch_size=40
 enable_save_ckpt=true
-enable_lossscale=true
 do_shuffle=true
 enable_data_sink=true
 data_sink_steps=100
 accumulation_steps=1
 allreduce_post_accumulation=true
 save_checkpoint_path=./
-save_checkpoint_steps=10000
 save_checkpoint_num=1
+config_path=../../pretrain_config.yaml
diff --git a/model_zoo/official/nlp/bert/scripts/run_distributed_pretrain_ascend.sh b/model_zoo/official/nlp/bert/scripts/run_distributed_pretrain_ascend.sh
index 7813e62fbef..606e0922ce4 100644
--- a/model_zoo/official/nlp/bert/scripts/run_distributed_pretrain_ascend.sh
+++ b/model_zoo/official/nlp/bert/scripts/run_distributed_pretrain_ascend.sh
@@ -30,6 +30,7 @@ python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_pretrain_cm
     --data_dir=$1 \
     --hccl_config_dir=$2 \
     --hccl_time_out=600 \
+    --hccn_config_file='/etc/hccn.conf' \
     --cmd_file=distributed_cmd.sh
 
 bash distributed_cmd.sh
diff --git a/tests/st/networks/models/bert/bert_performance/test_bert_thor.py b/tests/st/networks/models/bert/bert_performance/test_bert_thor.py
index 69696e3928d..ebbf4692045 100644
--- a/tests/st/networks/models/bert/bert_performance/test_bert_thor.py
+++ b/tests/st/networks/models/bert/bert_performance/test_bert_thor.py
@@ -53,10 +53,10 @@ bert_net_cfg = BertConfig(
     seq_length=512,
     vocab_size=30522,
     hidden_size=1024,
-    num_hidden_layers=4,
+    num_hidden_layers=6,
     num_attention_heads=16,
     intermediate_size=4096,
-    hidden_act="gelu",
+    hidden_act="fast_gelu",
     hidden_dropout_prob=0.1,
     attention_probs_dropout_prob=0.1,
     max_position_embeddings=512,
@@ -166,7 +166,7 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num):
 
     lr = get_bert_thor_lr()
     damping = get_bert_thor_damping()
-    split_indices = None
+    split_indices = [13, 37, 41]
     optimizer = thor(net_with_loss, lr, damping, momentum, weight_decay, loss_scale, batch_size,
                      decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
                      split_indices=split_indices, enable_clip_grad=True, frequency=frequency)
@@ -233,7 +233,7 @@ def test_bert_thor_8p():
         os.system("rm -rf " + str(i))
     print("End training...")
 
-    assert mean_cost < 69
+    assert mean_cost < 96
     assert mean_loss < 8.125
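
For reference, a minimal standalone sketch (not part of the patch) of the device-IP lookup that the new --hccn_config_file option in get_distribute_pretrain_cmd.py makes configurable. The helper name parse_hccn_conf and the prefix-stripping step are assumptions for illustration only; the overall flow mirrors the address_<id>=<ip> parsing and the empty-file check visible in the diff above.

# Hypothetical helper (illustration only), assuming an hccn.conf-style file with
# lines like "address_0=<device_ip>"; the real launcher reads args.hccn_config_file.
def parse_hccn_conf(hccn_config_file='/etc/hccn.conf'):
    device_ips = {}
    with open(hccn_config_file, 'r') as fin:
        for hccn_item in fin.readlines():
            if hccn_item.strip().startswith('address_'):
                key, device_ip = hccn_item.split('=')
                # keep only the id after the "address_" prefix
                # (simplified compared to the repository code)
                device_id = key.strip()[len('address_'):]
                device_ips[device_id] = device_ip.strip()
    if not device_ips:
        raise ValueError("There is no address in hccn.conf file.")
    return device_ips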