!22488 bert script adjust for thor generalization

Merge pull request !22488 from wangshuangling/master
This commit is contained in:
i-robot 2021-08-28 07:38:32 +00:00 committed by Gitee
commit d0ab539ea2
8 changed files with 212 additions and 15 deletions

View File

@ -755,13 +755,13 @@ class ThorAscend(Optimizer):
if self.conv_layer_count > 0:
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4")
self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=3)
self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=5)
self.grad_reducer_amax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=2)
self.grad_reducer_gmax = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=4)
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=9)
self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=17)
self.grad_reducer_a = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=6)
self.grad_reducer_g = DistributedGradReducer(self.matrix_a_cov, mean, degree, fusion_type=8)
def _process_matrix_init_and_weight_idx_map(self, net):
"""for Ascend, process matrix init shape, and get weight idx map"""

View File

@ -26,7 +26,7 @@ accumulation_steps: 1
allreduce_post_accumulation: 'true'
save_checkpoint_path: ''
load_checkpoint_path: ''
save_checkpoint_steps: 1000
save_checkpoint_steps: 10000
train_steps: -1
save_checkpoint_num: 1
data_dir: ''

View File

@ -0,0 +1,194 @@
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# Url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "Ascend"
enable_profiling: False
# ==============================================================================
description: 'run_pretrain'
distribute: 'false'
epoch_size: 40
device_id: 0
device_num: 1
enable_save_ckpt: 'true'
enable_lossscale: 'false'
do_shuffle: 'true'
enable_data_sink: 'true'
data_sink_steps: 100
accumulation_steps: 1
allreduce_post_accumulation: 'true'
save_checkpoint_path: ''
load_checkpoint_path: ''
save_checkpoint_steps: 500
train_steps: 2500
save_checkpoint_num: 5
data_dir: ''
schema_dir: ''
# ==============================================================================
# pretrain related
batch_size: 20
# Available: [base, nezha, large, large_acc]
bert_network: 'large_acc'
loss_scale_value: 65536
scale_factor: 2
scale_window: 1000
optimizer: 'Thor'
enable_global_norm: False
# pretrain_eval related
data_file: ""
schema_file: null
finetune_ckpt: ""
# optimizer related
AdamWeightDecay:
learning_rate: 0.00003 # 3e-5
end_learning_rate: 0.0
power: 5.0
weight_decay: 0.00001 # 1e-5
decay_filter: ['layernorm', 'bias']
eps: 0.000001 # 1e-6
warmup_steps: 10000
Lamb:
learning_rate: 0.0003 # 3e-4
end_learning_rate: 0.0
power: 2.0
warmup_steps: 10000
weight_decay: 0.01
decay_filter: ['layernorm', 'bias']
eps: 0.00000001 # 1e-8,
Momentum:
learning_rate: 0.00002 # 2e-5
momentum: 0.9
Thor:
lr_max: 0.006464
lr_min: 0.000001 # 1e-6
lr_power: 2.0
lr_total_steps: 30000
damping_max: 0.007035
damping_min: 0.000001 # 1e-6
damping_power: 4.0
damping_total_steps: 30000
momentum: 0.9
weight_decay: 0.00001 # 1e-5
loss_scale: 1024.0
frequency: 100
# ==============================================================================
# base
base_batch_size: 256
base_net_cfg:
seq_length: 128
vocab_size: 21128
hidden_size: 768
num_hidden_layers: 12
num_attention_heads: 12
intermediate_size: 3072
hidden_act: "gelu"
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
max_position_embeddings: 512
type_vocab_size: 2
initializer_range: 0.02
use_relative_positions: False
dtype: mstype.float32
compute_type: mstype.float16
# nezha
nezha_batch_size: 96
nezha_net_cfg:
seq_length: 128
vocab_size: 21128
hidden_size: 1024
num_hidden_layers: 24
num_attention_heads: 16
intermediate_size: 4096
hidden_act: "gelu"
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
max_position_embeddings: 512
type_vocab_size: 2
initializer_range: 0.02
use_relative_positions: True
dtype: mstype.float32
compute_type: mstype.float16
# large
large_batch_size: 20
large_net_cfg:
seq_length: 512
vocab_size: 30522
hidden_size: 1024
num_hidden_layers: 24
num_attention_heads: 16
intermediate_size: 4096
hidden_act: "gelu"
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
max_position_embeddings: 512
type_vocab_size: 2
initializer_range: 0.02
use_relative_positions: False
dtype: mstype.float32
compute_type: mstype.float16
# Accelerated large network which is only supported in Ascend yet.
large_acc_batch_size: 20
large_acc_net_cfg:
seq_length: 512
vocab_size: 30522
hidden_size: 1024
num_hidden_layers: 24
num_attention_heads: 16
intermediate_size: 4096
hidden_act: "fast_gelu"
hidden_dropout_prob: 0.1
attention_probs_dropout_prob: 0.1
max_position_embeddings: 512
type_vocab_size: 2
initializer_range: 0.02
use_relative_positions: False
dtype: mstype.float32
compute_type: mstype.float16
---
# Help description for each configuration
enable_modelarts: "Whether training on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of the input data."
output_path: "The location of the output file."
device_target: "Running platform, choose from Ascend or CPU, and default is Ascend."
enable_profiling: 'Whether enable profiling while training, default: False'
distribute: "Run distribute, default is 'false'."
epoch_size: "Epoch size, default is 1."
enable_save_ckpt: "Enable save checkpoint, default is true."
enable_lossscale: "Use lossscale or not, default is not."
do_shuffle: "Enable shuffle for dataset, default is true."
enable_data_sink: "Enable data sink, default is true."
data_sink_steps: "Sink steps for each epoch, default is 1."
accumulation_steps: "Accumulating gradients N times before weight update, default is 1."
allreduce_post_accumulation: "Whether to allreduce after accumulation of N steps or after each step, default is true."
save_checkpoint_path: "Save checkpoint path"
load_checkpoint_path: "Load checkpoint file path"
save_checkpoint_steps: "Save checkpoint steps, default is 1000"
train_steps: "Training Steps, default is -1, meaning run all steps according to epoch number."
save_checkpoint_num: "Save checkpoint numbers, default is 1."
data_dir: "Data path, it is better to use absolute path"
schema_dir: "Schema path, it is better to use absolute path"
---
# chocies
device_target: ['Ascend', 'GPU']
distribute: ["true", "false"]
enable_save_ckpt: ["true", "false"]
enable_lossscale: ["true", "false"]
do_shuffle: ["true", "false"]
enable_data_sink: ["true", "false"]
allreduce_post_accumulation: ["true", "false"]

View File

@ -150,10 +150,11 @@ def MLM_eval():
res = net.eval(dataset, dataset_sink_mode=False)
print("==============================================================")
for _, v in res.items():
print("Accuracy is: ")
print(v)
print("Accuracy is: ", v)
print("==============================================================")
if __name__ == "__main__":
DEVICE_ID = 0
os.environ['DEVICE_ID'] = str(DEVICE_ID)
MLM_eval()

View File

@ -47,6 +47,8 @@ def parse_args():
parser.add_argument("--hccl_time_out", type=int, default=120,
help="Seconds to determine the hccl time out,"
"default: 120, which is the same as hccl default config")
parser.add_argument("--hccn_config_file", type=str, default="/etc/hccn.conf",
help="Path of the hccn.conf file.")
args = parser.parse_args()
return args
@ -128,7 +130,7 @@ def distribute_pretrain():
# get device_ips
device_ips = {}
physic_logic_ids = {}
with open('/etc/hccn.conf', 'r') as fin:
with open(args.hccn_config_file, 'r') as fin:
for hccn_item in fin.readlines():
if hccn_item.strip().startswith('address_'):
device_id, device_ip = hccn_item.split('=')
@ -136,7 +138,7 @@ def distribute_pretrain():
device_ips[device_id] = device_ip.strip()
if not device_ips:
raise ValueError("There is no address in /etc/hccn.conf")
raise ValueError("There is no address in hccn.conf file.")
for logic_id, device_id in enumerate(sorted(device_ips.keys())):
physic_logic_ids[device_id] = logic_id

View File

@ -2,12 +2,11 @@
distribute=true
epoch_size=40
enable_save_ckpt=true
enable_lossscale=true
do_shuffle=true
enable_data_sink=true
data_sink_steps=100
accumulation_steps=1
allreduce_post_accumulation=true
save_checkpoint_path=./
save_checkpoint_steps=10000
save_checkpoint_num=1
config_path=../../pretrain_config.yaml

View File

@ -30,6 +30,7 @@ python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_pretrain_cm
--data_dir=$1 \
--hccl_config_dir=$2 \
--hccl_time_out=600 \
--hccn_config_file='/etc/hccn.conf' \
--cmd_file=distributed_cmd.sh
bash distributed_cmd.sh

View File

@ -53,10 +53,10 @@ bert_net_cfg = BertConfig(
seq_length=512,
vocab_size=30522,
hidden_size=1024,
num_hidden_layers=4,
num_hidden_layers=6,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_act="fast_gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
@ -166,7 +166,7 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num):
lr = get_bert_thor_lr()
damping = get_bert_thor_damping()
split_indices = None
split_indices = [13, 37, 41]
optimizer = thor(net_with_loss, lr, damping, momentum, weight_decay, loss_scale, batch_size,
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
split_indices=split_indices, enable_clip_grad=True, frequency=frequency)
@ -233,7 +233,7 @@ def test_bert_thor_8p():
os.system("rm -rf " + str(i))
print("End training...")
assert mean_cost < 69
assert mean_cost < 96
assert mean_loss < 8.125