parse bert dtype config

This commit is contained in:
chenhaozhe 2021-07-14 18:32:51 +08:00
parent 1767f3acf7
commit 011eab9f4b
1 changed file with 18 additions and 6 deletions

View File

@@ -112,6 +112,16 @@ def merge(args, cfg):
return cfg
def parse_dtype(dtype):
    """
    Convert a dtype name read from the config into the matching ``mstype`` object.

    Args:
        dtype: Configured dtype name. Only the strings "mstype.float32" and
            "mstype.float16" are accepted.

    Returns:
        ``mstype.float32`` or ``mstype.float16``, matching the given name.

    Raises:
        ValueError: If ``dtype`` is not one of the two supported names.
    """
    if dtype == "mstype.float32":
        return mstype.float32
    if dtype == "mstype.float16":
        return mstype.float16
    # Every supported name has returned above, so anything reaching this point
    # is unsupported; there is no fall-through path that yields None.
    raise ValueError("Not supported dtype")
def extra_operations(cfg):
"""
Do extra work on config
@@ -125,12 +135,14 @@ def extra_operations(cfg):
if cfg.description == 'run_pretrain':
cfg.AdamWeightDecay.decay_filter = create_filter_fun(cfg.AdamWeightDecay.decay_filter)
cfg.Lamb.decay_filter = create_filter_fun(cfg.Lamb.decay_filter)
cfg.base_net_cfg.dtype = mstype.float32
cfg.base_net_cfg.compute_type = mstype.float16
cfg.nezha_net_cfg.dtype = mstype.float32
cfg.nezha_net_cfg.compute_type = mstype.float16
cfg.large_net_cfg.dtype = mstype.float32
cfg.large_net_cfg.compute_type = mstype.float16
cfg.base_net_cfg.dtype = parse_dtype(cfg.base_net_cfg.dtype)
cfg.base_net_cfg.compute_type = parse_dtype(cfg.base_net_cfg.compute_type)
cfg.nezha_net_cfg.dtype = parse_dtype(cfg.nezha_net_cfg.dtype)
cfg.nezha_net_cfg.compute_type = parse_dtype(cfg.nezha_net_cfg.compute_type)
cfg.large_net_cfg.dtype = parse_dtype(cfg.large_net_cfg.dtype)
cfg.large_net_cfg.compute_type = parse_dtype(cfg.large_net_cfg.compute_type)
cfg.large_acc_net_cfg.dtype = parse_dtype(cfg.large_acc_net_cfg.dtype)
cfg.large_acc_net_cfg.compute_type = parse_dtype(cfg.large_acc_net_cfg.compute_type)
if cfg.bert_network == 'base':
cfg.batch_size = cfg.base_batch_size
_bert_net_cfg = cfg.base_net_cfg